diff --git a/alibi_detect/cd/pytorch/context_aware.py b/alibi_detect/cd/pytorch/context_aware.py
index d068a182c..cda6c12fe 100644
--- a/alibi_detect/cd/pytorch/context_aware.py
+++ b/alibi_detect/cd/pytorch/context_aware.py
@@ -230,8 +230,8 @@ def _pick_lam(self, lams: torch.Tensor, K: torch.Tensor, L: torch.Tensor, n_fold
         K, L = K[perm][:, perm], L[perm][:, perm]
         losses = torch.zeros_like(lams, dtype=torch.float).to(K.device)
         for fold in range(n_folds):
-            inds_oof = np.arange(n)[(fold*fold_size):((fold+1)*fold_size)]
-            inds_if = np.setdiff1d(np.arange(n), inds_oof)
+            inds_oof = list(np.arange(n)[(fold*fold_size):((fold+1)*fold_size)])
+            inds_if = list(np.setdiff1d(np.arange(n), inds_oof))
             K_if, L_if = K[inds_if][:, inds_if], L[inds_if][:, inds_if]
             n_if = len(K_if)
             L_inv_lams = torch.stack(
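
Note (not part of the patch): a minimal standalone sketch of the fold-splitting pattern this hunk touches. The sizes n, n_folds and the kernel matrix K below are toy placeholders, not values from the library; the point is only that the out-of-fold / in-fold indices are built as plain Python lists, as on the + lines above, before being used for advanced indexing on a torch tensor.

    import numpy as np
    import torch

    n, n_folds = 12, 3                         # toy sizes, hypothetical
    fold_size = n // n_folds
    K = torch.rand(n, n, dtype=torch.float64)  # stand-in kernel matrix

    for fold in range(n_folds):
        # out-of-fold indices for this fold, converted to a plain Python list
        inds_oof = list(np.arange(n)[(fold * fold_size):((fold + 1) * fold_size)])
        # in-fold indices: everything not held out in this fold
        inds_if = list(np.setdiff1d(np.arange(n), inds_oof))
        # advanced indexing with lists selects the in-fold rows, then columns
        K_if = K[inds_if][:, inds_if]
        assert K_if.shape == (len(inds_if), len(inds_if))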