Skip to content

Commit

Permalink
Merge pull request #64 from fzi-forschungszentrum-informatik/ng_checks
Browse files Browse the repository at this point in the history
Ng checks
  • Loading branch information
JHoelli authored Apr 3, 2024
2 parents 850407b + 4c1fdd5 commit 228374d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 16 deletions.
Binary file modified ClassificationModels/models/ElectricDevices/OneHotEncoder.pkl
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def _native_guide_wrapper(self, query, predicted_label, distance, n_neighbors):
def _findSubarray(
self, a, k
): # used to find the maximum contigious subarray of length k in the explanation weight vector

if len(a.shape) == 2:
a = a.reshape(-1)
n = len(a)
Expand All @@ -178,6 +179,8 @@ def _findSubarray(
# Store the sub-array elements in the array
for j in range(i, i + k):
temp.append(a[j])



# Push the vector in the container
vec.append(temp)
Expand Down Expand Up @@ -217,6 +220,8 @@ def _counterfactual_generator_swap(

if np.any(np.isnan(most_influencial_array)):
return np.full(individual.shape, None), None
if len(training_weights)==1:
training_weights=training_weights[0]

starting_point = np.where(training_weights == most_influencial_array[0])[0][0]

Expand All @@ -240,6 +245,8 @@ def _counterfactual_generator_swap(
most_influencial_array = self._findSubarray(
(training_weights), subarray_length
)
if len(training_weights)==1:
training_weights=training_weights[0]
starting_point = np.where(training_weights == most_influencial_array[0])[0][
0
]
Expand Down
2 changes: 1 addition & 1 deletion TSInterpret/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION = (0, 4, 4)
VERSION = (0, 4, 5)
__version__ = ".".join(map(str, VERSION)) # noqa: F401
44 changes: 30 additions & 14 deletions docs/Notebooks/NunCF_torch.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"partd==1.2.0",
"pytz>=2021.3",
"shap>=0.39.0,< 1.0",
"tensorflow>=2.9.1,< 2.15.0",
"tensorflow>=2.9.1,< 2.14.1",
"keras>=2.9.0,< 3.0",
"tsfresh>=0.18.0,< 1.0",
"tslearn>= 0.5.2,< 1.0",
Expand Down

0 comments on commit 228374d

Please sign in to comment.