diff --git a/alibi/explainers/counterfactual.py b/alibi/explainers/counterfactual.py index dcee0abe0..548474691 100644 --- a/alibi/explainers/counterfactual.py +++ b/alibi/explainers/counterfactual.py @@ -4,35 +4,9 @@ import keras import logging -logger = logging.getLogger(__name__) - - -def cityblock_batch(X: np.ndarray, - y: np.ndarray) -> np.ndarray: - """ - Calculate the L1 distances between a batch of arrays X and an array of the same shape y. - - Parameters - ---------- - X - Batch of arrays to calculate the distances from - y - Array to calculate the distance to - - Returns - ------- - Array of distances from each array in X to y - - """ - X_dim = len(X.shape) - y_dim = len(y.shape) - - if X_dim == y_dim: - assert y.shape[0] == 1, 'y mush have batch size equal to 1' - else: - assert X.shape[1:] == y.shape, 'X and y must have matching shapes' +from alibi.utils.gradients import num_grad_batch - return np.abs(X - y).sum(axis=tuple(np.arange(1, X_dim))).reshape(X.shape[0], -1) +logger = logging.getLogger(__name__) def _define_func(predict_fn: Callable, @@ -84,104 +58,27 @@ def func(X): # type: ignore return func, target_class -def _perturb(X: np.ndarray, - eps: Union[float, np.ndarray] = 1e-08, - proba: bool = False) -> Tuple[np.ndarray, np.ndarray]: - """ - Apply perturbation to instance or prediction probabilities. Used for numerical calculation of gradients. - - Parameters - ---------- - X - Array to be perturbed - eps - Size of perturbation - proba - If True, the net effect of the perturbation needs to be 0 to keep the sum of the probabilities equal to 1 - - Returns - ------- - Instances where a positive and negative perturbation is applied. - """ - # N = batch size; F = nb of features in X - shape = X.shape - X = np.reshape(X, (shape[0], -1)) # NxF - dim = X.shape[1] # F - pert = np.tile(np.eye(dim) * eps, (shape[0], 1)) # (N*F)xF - if proba: - eps_n = eps / (dim - 1) - pert += np.tile((np.eye(dim) - np.ones((dim, dim))) * eps_n, (shape[0], 1)) # (N*F)xF - X_rep = np.repeat(X, dim, axis=0) # (N*F)xF - X_pert_pos, X_pert_neg = X_rep + pert, X_rep - pert - shape = (dim * shape[0],) + shape[1:] - X_pert_pos = np.reshape(X_pert_pos, shape) # (N*F)x(shape of X[0]) - X_pert_neg = np.reshape(X_pert_neg, shape) # (N*F)x(shape of X[0]) - return X_pert_pos, X_pert_neg - - -def num_grad_batch(func: Callable, - X: np.ndarray, - args: Tuple = (), - eps: Union[float, np.ndarray] = 1e-08) -> np.ndarray: - """ - Calculate the numerical gradients of a vector-valued function (typically a prediction function in classification) - with respect to a batch of arrays X. 
- - Parameters - ---------- - func - Function to be differentiated - X - A batch of vectors at which to evaluate the gradient of the function - args - Any additional arguments to pass to the function - eps - Gradient step to use in the numerical calculation, can be a single float or one for each feature - - Returns - ------- - An array of gradients at each point in the batch X - - """ - # N = gradient batch size; F = nb of features in X, P = nb of prediction classes, B = instance batch size - batch_size = X.shape[0] - data_shape = X[0].shape - preds = func(X, *args) - X_pert_pos, X_pert_neg = _perturb(X, eps) # (N*F)x(shape of X[0]) - X_pert = np.concatenate([X_pert_pos, X_pert_neg], axis=0) - preds_concat = func(X_pert, *args) # make predictions - n_pert = X_pert_pos.shape[0] - - grad_numerator = preds_concat[:n_pert] - preds_concat[n_pert:] # (N*F)*P - grad_numerator = np.reshape(np.reshape(grad_numerator, (batch_size, -1)), - (batch_size, preds.shape[1], -1), order='F') # NxPxF - - grad = grad_numerator / (2 * eps) # NxPxF - grad = grad.reshape(preds.shape + data_shape) # BxPx(shape of X[0]) - - return grad - - class CounterFactual: def __init__(self, sess: tf.Session, - predict_fn: Union[Callable, tf.keras.Model], + predict_fn: Union[Callable, tf.keras.Model, keras.Model], data_shape: Tuple[int, ...], distance_fn: str = 'l1', target_proba: float = 1.0, target_class: Union[str, int] = 'other', max_iter: int = 1000, - lam_init: float = 1e-04, + early_stop: int = 50, + lam_init: float = 1e-1, max_lam_steps: int = 10, tol: float = 0.05, learning_rate_init=0.1, - feature_range: Union[Tuple, str] = (-1e10, 1e10), # important for positive features + feature_range: Union[Tuple, str] = (-1e10, 1e10), eps: Union[float, np.ndarray] = 0.01, # feature-wise epsilons init: str = 'identity', - decay=True, + decay: bool = True, write_dir: str = None, - debug=False) -> None: + debug: bool = False) -> None: """ Initialize counterfactual explanation method based on Wachter et al. 
(2017) @@ -202,6 +99,8 @@ def __init__(self, desired class membership for the counterfactual instance max_iter Maximum number of interations to run the gradient descent for (inner loop) + early_stop + Number of steps after which to terminate gradient descent if all or none of found instances are solutions lam_init Initial regularization constant for the prediction part of the Wachter loss max_lam_steps @@ -217,7 +116,7 @@ def __init__(self, Gradient step sizes used in calculating numerical gradients, defaults to a single value for all features, but can be passed an array for feature-wise step sizes init - Initialization method for the search of counterfactuals, one of 'random' or 'identity' + Initialization method for the search of counterfactuals, currently must be 'identity' decay Flag to decay learning rate to zero for each outer loop over lambda write_dir @@ -237,6 +136,7 @@ def __init__(self, self.lam_init = lam_init self.tol = tol self.max_lam_steps = max_lam_steps + self.early_stop = early_stop self.eps = eps self.init = init @@ -249,13 +149,13 @@ def __init__(self, self.model = True self.predict_fn = predict_fn.predict # array function self.predict_tn = predict_fn # tensor function - self.n_classes = self.sess.run(self.predict_tn(tf.convert_to_tensor(np.zeros(data_shape), - dtype=tf.float32))).shape[1] + else: # black-box model self.predict_fn = predict_fn self.predict_tn = None self.model = False - self.n_classes = self.predict_fn(np.zeros(data_shape)).shape[1] + + self.n_classes = self.predict_fn(np.zeros(data_shape)).shape[1] # flag to keep track if explainer is fit or not self.fitted = False @@ -270,22 +170,23 @@ def __init__(self, constraint=lambda x: tf.clip_by_value(x, feature_range[0], feature_range[1])) # the following will be a 1-hot encoding of the target class (as predicted by the model) self.target = tf.get_variable('target', shape=(self.batch_size, self.n_classes), dtype=tf.float32) - self.lam = tf.Variable(self.lam_init * np.ones(self.batch_size), name='lambda', dtype=tf.float32) # constant target probability and global step variable self.target_proba = tf.constant(target_proba * np.ones(self.batch_size), dtype=tf.float32, name='target_proba') self.global_step = tf.Variable(0.0, trainable=False, name='global_step') + # lambda hyperparameter - placeholder instead of variable as annealed in first epoch + self.lam = tf.placeholder(tf.float32, shape=(self.batch_size), name='lam') + # define placeholders that will be assigned to relevant variables self.assign_orig = tf.placeholder(tf.float32, data_shape, name='assing_orig') self.assign_cf = tf.placeholder(tf.float32, data_shape, name='assign_cf') self.assign_target = tf.placeholder(tf.float32, shape=(self.batch_size, self.n_classes), name='assign_target') - self.assign_lam = tf.placeholder(tf.float32, shape=(self.batch_size), name='assign_lam') # L1 distance and MAD constants - # TODO: refactor? MADs? + # TODO: MADs? 
ax_sum = list(np.arange(1, len(self.data_shape))) if distance_fn == 'l1': self.dist = tf.reduce_sum(tf.abs(self.cf - self.orig), axis=ax_sum, name='l1') @@ -309,9 +210,8 @@ def __init__(self, self.pred_proba_class = tf.reduce_max(self.target * self.pred_proba, 1) elif target_class == 'other': self.pred_proba_class = tf.reduce_max((1 - self.target) * self.pred_proba, 1) - elif isinstance(target_class, int): + elif target_class in range(self.n_classes): # if class is specified, this is known in advance - # TODO: try/except to handle invalid cases self.pred_proba_class = tf.reduce_max(tf.one_hot(target_class, self.n_classes, dtype=tf.float32) * self.pred_proba, 1) else: @@ -325,9 +225,9 @@ def __init__(self, # optimizer if decay: self.learning_rate = tf.train.polynomial_decay(learning_rate_init, self.global_step, - self.max_iter, 0.0, power=0.5) + self.max_iter, 0.0, power=1.0) else: - self.learning_rate = learning_rate_init + self.learning_rate = tf.convert_to_tensor(learning_rate_init) # TODO optional argument to change type, learning rate scheduler opt = tf.train.AdamOptimizer(self.learning_rate) @@ -343,7 +243,6 @@ def __init__(self, self.setup.append(self.orig.assign(self.assign_orig)) self.setup.append(self.cf.assign(self.assign_cf)) self.setup.append(self.target.assign(self.assign_target)) - self.setup.append(self.lam.assign(self.assign_lam)) self.tf_init = tf.variables_initializer(var_list=tf.global_variables(scope='cf_search')) @@ -352,18 +251,18 @@ def __init__(self, self.writer = tf.summary.FileWriter(write_dir, tf.get_default_graph()) self.writer.add_graph(tf.get_default_graph()) + # return templates + self.instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'pred_class', 'prob', 'loss']) + self.return_dict = {'cf': None, 'all': [], 'orig_class': None, 'orig_prob': None} # type: dict + def _initialize(self, X: np.ndarray) -> np.ndarray: # TODO initialization strategies ("same", "random", "from_train") if self.init == 'identity': X_init = X logger.debug('Initializing search at the test point X') - elif self.init == 'random': - # TODO: handle ranges - X_init = np.random.rand(*self.data_shape) - logger.debug('Initializing search at a random test point') else: - raise ValueError('Initialization method should be one of "random" or "identity"') + raise ValueError('Initialization method should be "identity"') return X_init @@ -382,71 +281,163 @@ def explain(self, X: np.ndarray) -> Dict: 'but first dim = %s', X.shape[0]) # make a prediction - if self.model: - Y = self.sess.run(self.predict_tn(tf.convert_to_tensor(X, dtype=tf.float32))) - else: - Y = self.predict_fn(X) + Y = self.predict_fn(X) + + pred_class = Y.argmax(axis=1).item() + pred_prob = Y.max(axis=1).item() + self.return_dict['orig_class'] = pred_class + self.return_dict['orig_prob'] = pred_prob - pred_class = Y.argmax() - logger.debug('Initial prediction: %s with p=%s', pred_class, Y.max()) + logger.debug('Initial prediction: %s with p=%s', pred_class, pred_prob) # define the class-specific prediction function self.predict_class_fn, t_class = _define_func(self.predict_fn, pred_class, self.target_class) - #if not self.fitted: - # logger.warning('Explain called before fit, explainer will operate in unsupervised mode.') - # initialize with an instance X_init = self._initialize(X) # minimize loss iteratively - exp_dict = self._minimize_loss(X, X_init, Y) + self._minimize_loss(X, X_init, Y) + + return_dict = self.return_dict.copy() + self.return_dict = {'cf': None, 'all': [], 'orig_class': None, 'orig_prob': None} - 
return exp_dict + return return_dict def _prob_condition(self, X_current): return np.abs(self.predict_class_fn(X_current) - self.target_proba_arr) <= self.tol + def _update_exp(self, i, l_step, lam, cf_found, X_current): + cf_found[0][l_step] += 1 # TODO: batch support + dist = self.sess.run(self.dist).item() + + # populate the return dict + self.instance_dict['X'] = X_current + self.instance_dict['distance'] = dist + self.instance_dict['lambda'] = lam[0] + self.instance_dict['index'] = l_step * self.max_iter + i + + preds = self.predict_fn(X_current) + pred_class = preds.argmax() + prob = preds.max() + self.instance_dict['pred_class'] = pred_class + self.instance_dict['prob'] = prob + + self.instance_dict['loss'] = (self.instance_dict['prob'] - self.target_proba_arr[0]) ** 2 + \ + self.instance_dict['lambda'] * self.instance_dict['distance'] + + self.return_dict['all'].append(self.instance_dict.copy()) + + # update best CF if it has a smaller distance + if self.return_dict['cf'] is None: + self.return_dict['cf'] = self.instance_dict.copy() + + elif dist < self.return_dict['cf']['distance']: + self.return_dict['cf'] = self.instance_dict.copy() + + logger.debug('CF found at step %s', l_step * self.max_iter + i) + + def _write_tb(self, lam, lam_lb, lam_ub, cf_found, X_current, **kwargs): + if self.model: + scalars_tf = [self.global_step, self.learning_rate, self.dist[0], + self.loss_pred[0], self.loss_opt[0], self.pred_proba_class[0]] + gs, lr, dist, loss_pred, loss_opt, pred = self.sess.run(scalars_tf, feed_dict={self.lam: lam}) + else: + scalars_tf = [self.global_step, self.learning_rate, self.dist[0], + self.loss_opt[0]] + gs, lr, dist, loss_opt = self.sess.run(scalars_tf, feed_dict={self.lam: lam}) + loss_pred = kwargs['loss_pred'] + pred = kwargs['pred'] + + try: + found = kwargs['found'] + not_found = kwargs['not_found'] + except KeyError: + found = 0 + not_found = 0 + + summary = tf.Summary() + summary.value.add(tag='lr/global_step', simple_value=gs) + summary.value.add(tag='lr/lr', simple_value=lr) + + summary.value.add(tag='lambda/lambda', simple_value=lam[0]) + summary.value.add(tag='lambda/l_bound', simple_value=lam_lb[0]) + summary.value.add(tag='lambda/u_bound', simple_value=lam_ub[0]) + + summary.value.add(tag='losses/dist', simple_value=dist) + summary.value.add(tag='losses/loss_pred', simple_value=loss_pred) + summary.value.add(tag='losses/loss_opt', simple_value=loss_opt) + summary.value.add(tag='losses/pred_div_dist', simple_value=loss_pred / (lam[0] * dist)) + + summary.value.add(tag='Y/pred_proba_class', simple_value=pred) + summary.value.add(tag='Y/pred_class_fn(X_current)', simple_value=self.predict_class_fn(X_current)) + summary.value.add(tag='Y/n_cf_found', simple_value=cf_found[0].sum()) + summary.value.add(tag='Y/found', simple_value=found) + summary.value.add(tag='Y/not_found', simple_value=not_found) + + self.writer.add_summary(summary) + self.writer.flush() + + def _bisect_lambda(self, cf_found, l_step, lam, lam_lb, lam_ub): + + for batch_idx in range(self.batch_size): # TODO: batch not supported + if cf_found[batch_idx][l_step] >= 5: # minimum number of CF instances to warrant increasing lambda + # want to improve the solution by putting more weight on the distance term TODO: hyperparameter? 
+ # by increasing lambda + lam_lb[batch_idx] = max(lam[batch_idx], lam_lb[batch_idx]) + logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx]) + if lam_ub[batch_idx] < 1e9: + lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2 + else: + lam[batch_idx] *= 10 + logger.debug('Changed lambda to %s', lam[batch_idx]) + + elif cf_found[batch_idx][l_step] < 5: + # if not enough solutions found so far, decrease lambda by a factor of 10, + # otherwise bisect up to the last known successful lambda + lam_ub[batch_idx] = min(lam_ub[batch_idx], lam[batch_idx]) + logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx]) + if lam_lb[batch_idx] > 0: + lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2 + logger.debug('Changed lambda to %s', lam[batch_idx]) + else: + lam[batch_idx] /= 10 + + return lam, lam_lb, lam_ub + def _minimize_loss(self, X: np.ndarray, X_init: np.ndarray, - Y: np.ndarray) -> Dict: - - # keep track of found CFs for each lambda in outer loop - cf_found = np.zeros((self.batch_size, self.max_lam_steps), dtype=bool) + Y: np.ndarray) -> None: - # returned explanation as the best counterfactual+metrics and all the other samples found on the way that - # satisfy the probability constraint - return_dict = {'cf': None, 'all': [], 'orig_class': Y.argmax(),'orig_prob': Y.max()} - instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'pred_class', 'prob', 'loss']) + # keep track of the number of CFs found for each lambda in outer loop + cf_found = np.zeros((self.batch_size, self.max_lam_steps)) # set the lower and upper bound for lamda to scale the distance loss term - lam = np.ones(self.batch_size) * self.lam_init lam_lb = np.zeros(self.batch_size) lam_ub = np.ones(self.batch_size) * 1e10 - lam_steps = 0 - X_current = X_init - # make a one-hot vector of targets Y_ohe = np.zeros(Y.shape) np.put(Y_ohe, np.argmax(Y, axis=1), 1) - for l_step in range(self.max_lam_steps): - self.sess.run(self.tf_init) - lr = self.sess.run(self.learning_rate) - logger.debug('Starting outer loop: %s/%s with lambda=%s, lr=%s', lam_steps + 1, self.max_lam_steps, lam, lr) + # on first run estimate lambda bounds + n_orders = 10 + n_steps = self.max_iter // n_orders + lams = np.array([self.lam_init / 10 ** i for i in range(n_orders)]) # exponential decay + cf_count = np.zeros_like(lams) + logger.debug('Initial lambda sweep: %s', lams) - # assign variables for the current iteration + X_current = X_init + # TODO this whole initial loop should be optional? 
+ for ix, l_step in enumerate(lams): + lam = np.ones(self.batch_size) * l_step + self.sess.run(self.tf_init) self.sess.run(self.setup, {self.assign_orig: X, self.assign_cf: X_current, - self.assign_target: Y_ohe, - self.assign_lam: lam}) + self.assign_target: Y_ohe}) - num_iter = 0 - - # number of gradient descent steps in each inner loop - for i in range(self.max_iter): + for i in range(n_steps): # numerical gradients grads_num = np.zeros(self.data_shape) @@ -462,96 +453,106 @@ def _minimize_loss(self, # add values to tensorboard (1st item in batch only) every n steps if self.debug and not i % 50: - if self.model: - scalars_tf = [self.global_step, self.learning_rate, self.lam[0], self.dist[0], - self.loss_pred[0], self.loss_opt[0], self.pred_proba_class[0]] - gs, lr, lm, dist, loss_pred, loss_opt, pred = self.sess.run(scalars_tf) + if not self.model: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, loss_pred=loss_pred, pred=pred) else: - scalars_tf = [self.global_step, self.learning_rate, self.lam[0], self.dist[0], - self.loss_opt[0]] - gs, lr, lm, dist, loss_opt = self.sess.run(scalars_tf) + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current) - summary = tf.Summary() - summary.value.add(tag='lr/global_step', simple_value=gs) - summary.value.add(tag='lr/lr', simple_value=lr) + # compute graph gradients + grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam}) + grads_graph = [g for g, _ in grads_vars_graph][0] - summary.value.add(tag='lambda/lambda', simple_value=lm) - summary.value.add(tag='lambda/l_bound', simple_value=lam_lb[0]) - summary.value.add(tag='lambda/u_bound', simple_value=lam_ub[0]) + # apply gradients + gradients = grads_graph + grads_num + self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam}) - summary.value.add(tag='losses/dist', simple_value=dist) - summary.value.add(tag='losses/loss_pred', simple_value=loss_pred) - summary.value.add(tag='losses/loss_opt', simple_value=loss_opt) + # does the counterfactual condition hold? + X_current = self.sess.run(self.cf) + cond = self._prob_condition(X_current).squeeze() + if cond: + cf_count[ix] += 1 + + # find the lower bound + logger.debug('cf_count: %s', cf_count) + try: + lb_ix = np.where(cf_count > 0)[0][1] # take the second order of magnitude with some CFs as lower-bound + # TODO robust? + except IndexError: + logger.exception('No appropriate lambda range found, try decreasing lam_init') + lam_lb = np.ones(self.batch_size) * lams[lb_ix] + + # find the upper bound + try: + ub_ix = np.where(cf_count == 0)[0][-1] # TODO is 0 robust? 
+ except IndexError: + ub_ix = 0 + logger.exception('Could not find upper bound for lambda where no solutions found, setting upper bound to ' + 'lam_init=%s', lams[ub_ix]) + lam_ub = np.ones(self.batch_size) * lams[ub_ix] + + # start the search in the middle + lam = (lam_lb + lam_ub) / 2 + + logger.debug('Found upper and lower bounds: %s, %s', lam_lb[0], lam_ub[0]) + + # on subsequent runs bisect lambda within the bounds found initially + X_current = X_init + for l_step in range(self.max_lam_steps): + self.sess.run(self.tf_init) - summary.value.add(tag='Y/pred_proba_class', simple_value=pred) - summary.value.add(tag='Y/pred_class_fn(X_current)', simple_value=self.predict_class_fn(X_current)) + # assign variables for the current iteration + self.sess.run(self.setup, {self.assign_orig: X, + self.assign_cf: X_current, + self.assign_target: Y_ohe}) - self.writer.add_summary(summary) - self.writer.flush() + found, not_found = 0, 0 + # number of gradient descent steps in each inner loop + for i in range(self.max_iter): + + # numerical gradients + grads_num = np.zeros(self.data_shape) + if not self.model: + pred = self.predict_class_fn(X_current) + prediction_grad = num_grad_batch(self.predict_class_fn, X_current, eps=self.eps) + + # squared difference prediction loss + loss_pred = (pred - self.target_proba.eval(session=self.sess)) ** 2 + grads_num = 2 * (pred - self.target_proba.eval(session=self.sess)) * prediction_grad + + grads_num = grads_num.reshape(self.data_shape) + + # add values to tensorboard (1st item in batch only) every n steps + if self.debug and not i % 50: + if not self.model: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, found=found, not_found=not_found, + loss_pred=loss_pred, pred=pred) + else: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, found=found, not_found=not_found) - # minimize the loss - num_iter += 1 # compute graph gradients - grads_vars_graph = self.sess.run(self.compute_grads) + grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam}) grads_graph = [g for g, _ in grads_vars_graph][0] # apply gradients gradients = grads_graph + grads_num - self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients}) + self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam}) # does the counterfactual condition hold? 
- if not i % 10: - X_current = self.sess.run(self.cf) - cond = self._prob_condition(X_current).squeeze() - if cond: - cf_found[0][l_step] = True # TODO: batch support - - # populate the return dict - instance_dict['X'] = X_current - instance_dict['distance'] = self.sess.run(self.dist).item() - instance_dict['lambda'] = lam[0] - instance_dict['index'] = l_step * self.max_iter + i - - preds = self.predict_fn(X_current) - pred_class = preds.argmax() - prob = preds.max() - instance_dict['pred_class'] = pred_class - instance_dict['prob'] = prob - - instance_dict['loss'] = (instance_dict['prob'] - self.target_proba_arr[0]) ** 2 + instance_dict[ - 'lambda'] * instance_dict['distance'] - - return_dict['cf'] = instance_dict.copy() - return_dict['all'].append(instance_dict.copy()) - - logger.debug('CF found') - - # adjust the lambda constant via bisection - for batch_idx in range(self.batch_size): # TODO: batch not supported - if cf_found[batch_idx][l_step]: - # want to improve the solution by putting more weight on the distance term - # by increasing lambda - lam_lb[batch_idx] = max(lam[batch_idx], lam_lb[batch_idx]) - logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx]) - if lam_ub[batch_idx] < 1e9: - lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2 - else: - lam[batch_idx] *= 10 - logger.debug('Changed lambda to %s', lam[batch_idx]) - - elif not cf_found[batch_idx][l_step]: - # if no solution found so far, decrease lambda by a factor of 10, - # otherwise bisect up to the last known successful lambda - lam_ub[batch_idx] = min(lam_ub[batch_idx], lam[batch_idx]) - logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx]) - if lam_lb[batch_idx] > 0: - lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2 - logger.debug('Changed lambda to %s', lam[batch_idx]) - else: - lam[batch_idx] /= 10 + X_current = self.sess.run(self.cf) + cond = self._prob_condition(X_current) + if cond: + self._update_exp(i, l_step, lam, cf_found, X_current) + found += 1 + not_found = 0 + else: + found = 0 + not_found += 1 - lam_steps += 1 + # early stopping criterion - if no solutions or enough solutions found, change lambda + if found >= self.early_stop or not_found >= self.early_stop: + break - return_dict['success'] = True + # adjust the lambda constant via bisection at the end of the outer loop + self._bisect_lambda(cf_found, l_step, lam, lam_lb, lam_ub) - return return_dict + self.return_dict['success'] = True diff --git a/alibi/explainers/tests/test_counterfactual.py b/alibi/explainers/tests/test_counterfactual.py index 71c427d6f..3c9323e62 100644 --- a/alibi/explainers/tests/test_counterfactual.py +++ b/alibi/explainers/tests/test_counterfactual.py @@ -9,7 +9,7 @@ from tensorflow.keras.layers import Dense import tensorflow.keras.backend as K -from alibi.explainers.counterfactual import _define_func, num_grad_batch, cityblock_batch +from alibi.explainers.counterfactual import _define_func from alibi.explainers import CounterFactual @@ -20,7 +20,7 @@ def logistic_iris(): return X, y, lr -@pytest.fixture() +@pytest.fixture def tf_keras_logistic_mnist(): (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data() input_dim = 784 @@ -53,8 +53,8 @@ def iris_explainer(request, logistic_iris): predict_fn = lr.predict_proba sess = tf.Session() cf_explainer = CounterFactual(sess=sess, predict_fn=predict_fn, data_shape=(1, 4), - target_class=request.param, lam_init=1e-4, max_iter=1000, - max_lam_steps=5) + target_class=request.param, lam_init=1e-1, 
max_iter=1000, + max_lam_steps=10) yield cf_explainer tf.reset_default_graph() @@ -67,11 +67,10 @@ def tf_keras_mnist_explainer(request, tf_keras_logistic_mnist): sess = K.get_session() cf_explainer = CounterFactual(sess=sess, predict_fn=model, data_shape=(1, 784), - target_class=request.param, lam_init=1e-4, max_iter=1000, - max_lam_steps=5) + target_class=request.param, lam_init=1e-1, max_iter=1000, + max_lam_steps=10) yield cf_explainer - @pytest.mark.parametrize('target_class', ['other', 'same', 0, 1, 2]) def test_define_func(logistic_iris, target_class): X, y, model = logistic_iris @@ -97,43 +96,7 @@ def test_define_func(logistic_iris, target_class): assert func(x) == probas[:, ix2] -@pytest.mark.parametrize('shape', [(1,), (2, 3), (1, 3, 5)]) -@pytest.mark.parametrize('batch_size', [1, 3, 10]) -def test_get_batch_num_gradients_cityblock(shape, batch_size): - u = np.random.rand(batch_size, *shape) - v = np.random.rand(1, *shape) - - grad_true = np.sign(u - v).reshape(batch_size, 1, *shape) # expand dims to incorporate 1-d scalar response - grad_approx = num_grad_batch(cityblock_batch, u, args=tuple([v])) - - assert grad_approx.shape == grad_true.shape - assert np.allclose(grad_true, grad_approx) - - -@pytest.mark.parametrize('batch_size', [1, 2, 5]) -def test_get_batch_num_gradients_logistic_iris(logistic_iris, batch_size): - X, y, lr = logistic_iris - predict_fn = lr.predict_proba - x = X[0:batch_size] - probas = predict_fn(x) - - # true gradient of the logistic regression wrt x - grad_true = np.zeros((batch_size, 3, 4)) - for i, p in enumerate(probas): - p = p.reshape(1, 3) - grad = (p.T * (np.eye(3, 3) - p) @ lr.coef_) - grad_true[i, :, :] = grad - assert grad_true.shape == (batch_size, 3, 4) - - grad_approx = num_grad_batch(predict_fn, x) - - assert grad_approx.shape == grad_true.shape - assert np.allclose(grad_true, grad_approx) - - -@pytest.mark.parametrize('iris_explainer', - ['other', 'same', 0, 1, 2], - indirect=True) +@pytest.mark.parametrize('iris_explainer', ['other', 'same', 0, 1, 2], indirect=True) def test_cf_explainer_iris(iris_explainer, logistic_iris): X, y, lr = logistic_iris x = X[0].reshape(1, -1) @@ -168,9 +131,7 @@ def test_cf_explainer_iris(iris_explainer, logistic_iris): assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol -@pytest.mark.parametrize('tf_keras_mnist_explainer', - ['other', 'same', 9], - indirect=True) +@pytest.mark.parametrize('tf_keras_mnist_explainer', ['other', 'same', 4, 9], indirect=True) def test_tf_keras_mnist_explainer(tf_keras_mnist_explainer, tf_keras_logistic_mnist): X, y, model = tf_keras_logistic_mnist x = X[0].reshape(1, -1) diff --git a/alibi/utils/distance.py b/alibi/utils/distance.py new file mode 100644 index 000000000..2d65874ec --- /dev/null +++ b/alibi/utils/distance.py @@ -0,0 +1,29 @@ +import numpy as np + + +def cityblock_batch(X: np.ndarray, + y: np.ndarray) -> np.ndarray: + """ + Calculate the L1 distances between a batch of arrays X and an array of the same shape y. 
+ + Parameters + ---------- + X + Batch of arrays to calculate the distances from + y + Array to calculate the distance to + + Returns + ------- + Array of distances from each array in X to y + + """ + X_dim = len(X.shape) + y_dim = len(y.shape) + + if X_dim == y_dim: + assert y.shape[0] == 1, 'y must have batch size equal to 1' + else: + assert X.shape[1:] == y.shape, 'X and y must have matching shapes' + + return np.abs(X - y).sum(axis=tuple(np.arange(1, X_dim))).reshape(X.shape[0], -1) diff --git a/alibi/utils/gradients.py b/alibi/utils/gradients.py new file mode 100644 index 000000000..32aa6a9c2 --- /dev/null +++ b/alibi/utils/gradients.py @@ -0,0 +1,80 @@ +from typing import Union, Tuple, Callable +import numpy as np + + +def perturb(X: np.ndarray, + eps: Union[float, np.ndarray] = 1e-08, + proba: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Apply perturbation to instance or prediction probabilities. Used for numerical calculation of gradients. + + Parameters + ---------- + X + Array to be perturbed + eps + Size of perturbation + proba + If True, the net effect of the perturbation needs to be 0 to keep the sum of the probabilities equal to 1 + + Returns + ------- + Instances where a positive and negative perturbation is applied. + """ + # N = batch size; F = nb of features in X + shape = X.shape + X = np.reshape(X, (shape[0], -1)) # NxF + dim = X.shape[1] # F + pert = np.tile(np.eye(dim) * eps, (shape[0], 1)) # (N*F)xF + if proba: + eps_n = eps / (dim - 1) + pert += np.tile((np.eye(dim) - np.ones((dim, dim))) * eps_n, (shape[0], 1)) # (N*F)xF + X_rep = np.repeat(X, dim, axis=0) # (N*F)xF + X_pert_pos, X_pert_neg = X_rep + pert, X_rep - pert + shape = (dim * shape[0],) + shape[1:] + X_pert_pos = np.reshape(X_pert_pos, shape) # (N*F)x(shape of X[0]) + X_pert_neg = np.reshape(X_pert_neg, shape) # (N*F)x(shape of X[0]) + return X_pert_pos, X_pert_neg + + +def num_grad_batch(func: Callable, + X: np.ndarray, + args: Tuple = (), + eps: Union[float, np.ndarray] = 1e-08) -> np.ndarray: + """ + Calculate the numerical gradients of a vector-valued function (typically a prediction function in classification) + with respect to a batch of arrays X. 
+ + Parameters + ---------- + func + Function to be differentiated + X + A batch of vectors at which to evaluate the gradient of the function + args + Any additional arguments to pass to the function + eps + Gradient step to use in the numerical calculation, can be a single float or one for each feature + + Returns + ------- + An array of gradients at each point in the batch X + + """ + # N = gradient batch size; F = nb of features in X, P = nb of prediction classes, B = instance batch size + batch_size = X.shape[0] + data_shape = X[0].shape + preds = func(X, *args) + X_pert_pos, X_pert_neg = perturb(X, eps) # (N*F)x(shape of X[0]) + X_pert = np.concatenate([X_pert_pos, X_pert_neg], axis=0) + preds_concat = func(X_pert, *args) # make predictions + n_pert = X_pert_pos.shape[0] + + grad_numerator = preds_concat[:n_pert] - preds_concat[n_pert:] # (N*F)*P + grad_numerator = np.reshape(np.reshape(grad_numerator, (batch_size, -1)), + (batch_size, preds.shape[1], -1), order='F') # NxPxF + + grad = grad_numerator / (2 * eps) # NxPxF + grad = grad.reshape(preds.shape + data_shape) # BxPx(shape of X[0]) + + return grad diff --git a/alibi/utils/tests/test_distance.py b/alibi/utils/tests/test_distance.py new file mode 100644 index 000000000..02d0db516 --- /dev/null +++ b/alibi/utils/tests/test_distance.py @@ -0,0 +1,25 @@ +import numpy as np +from scipy.spatial.distance import cityblock +from itertools import product +import pytest +from alibi.utils.distance import cityblock_batch + +dims = np.array([1, 10, 50]) +shapes = list(product(dims, dims)) +n_tests = len(dims)**2 + +@pytest.fixture +def random_matrix(request): + shape = shapes[request.param] + matrix = np.random.rand(*shape) + return matrix + +@pytest.mark.parametrize('random_matrix', list(range(n_tests)), indirect=True) +def test_cityblock_batch(random_matrix): + X = random_matrix + y = X[np.random.choice(X.shape[0])] + + batch_dists = cityblock_batch(X, y) + single_dists = np.array([cityblock(x, y) for x in X]).reshape(X.shape[0], -1) + + assert np.allclose(batch_dists, single_dists) \ No newline at end of file diff --git a/alibi/utils/tests/test_gradients.py b/alibi/utils/tests/test_gradients.py new file mode 100644 index 000000000..c8621bfde --- /dev/null +++ b/alibi/utils/tests/test_gradients.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +from alibi.utils.distance import cityblock_batch +from alibi.utils.gradients import num_grad_batch + +@pytest.fixture +def logistic_iris(): + X, y = load_iris(return_X_y=True) + lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y) + return X, y, lr + +@pytest.mark.parametrize('shape', [(1,), (2, 3), (1, 3, 5)]) +@pytest.mark.parametrize('batch_size', [1, 3, 10]) +def test_get_batch_num_gradients_cityblock(shape, batch_size): + u = np.random.rand(batch_size, *shape) + v = np.random.rand(1, *shape) + + grad_true = np.sign(u - v).reshape(batch_size, 1, *shape) # expand dims to incorporate 1-d scalar response + grad_approx = num_grad_batch(cityblock_batch, u, args=tuple([v])) + + assert grad_approx.shape == grad_true.shape + assert np.allclose(grad_true, grad_approx) + + +@pytest.mark.parametrize('batch_size', [1, 2, 5]) +def test_get_batch_num_gradients_logistic_iris(logistic_iris, batch_size): + X, y, lr = logistic_iris + predict_fn = lr.predict_proba + x = X[0:batch_size] + probas = predict_fn(x) + + # true gradient of the logistic regression wrt x + grad_true = 
np.zeros((batch_size, 3, 4)) + for i, p in enumerate(probas): + p = p.reshape(1, 3) + grad = (p.T * (np.eye(3, 3) - p) @ lr.coef_) + grad_true[i, :, :] = grad + assert grad_true.shape == (batch_size, 3, 4) + + grad_approx = num_grad_batch(predict_fn, x) + + assert grad_approx.shape == grad_true.shape + assert np.allclose(grad_true, grad_approx) \ No newline at end of file
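For reference, a minimal usage sketch of the helpers relocated into alibi/utils by this patch, mirroring the new tests in alibi/utils/tests; the input shapes below are illustrative only:

    import numpy as np
    from alibi.utils.distance import cityblock_batch
    from alibi.utils.gradients import num_grad_batch

    # batch of 3 instances with 4 features each, plus a single reference point
    X = np.random.rand(3, 4)
    y = np.random.rand(1, 4)

    dists = cityblock_batch(X, y)                          # shape (3, 1): L1 distance of each row of X to y
    grads = num_grad_batch(cityblock_batch, X, args=(y,))  # shape (3, 1, 4): numerical gradient per instance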
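Likewise, a sketch of constructing the explainer with the revised defaults (lam_init=1e-1, max_lam_steps=10, early_stop=50), following the iris fixture in alibi/explainers/tests/test_counterfactual.py; running it end to end outside the test harness as shown here is an assumption, not something exercised verbatim by the patch:

    import numpy as np
    import tensorflow as tf
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from alibi.explainers import CounterFactual

    X, y = load_iris(return_X_y=True)
    lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y)

    sess = tf.Session()  # TF1-style session, as in the test fixtures
    cf = CounterFactual(sess=sess, predict_fn=lr.predict_proba, data_shape=(1, 4),
                        target_class='other', lam_init=1e-1, max_iter=1000,
                        max_lam_steps=10, early_stop=50)

    explanation = cf.explain(X[0].reshape(1, -1))
    # return keys introduced in this patch: 'cf', 'all', 'orig_class', 'orig_prob', 'success'
    best = explanation['cf']
    print(best['pred_class'], best['prob'], best['distance'])

    sess.close()
    tf.reset_default_graph()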