From f3e5ee417cc5ec0f98697afdb5f4e57f279cc33e Mon Sep 17 00:00:00 2001 From: RobinGeens Date: Fri, 16 Aug 2024 14:12:01 +0200 Subject: [PATCH] fix bug in nb_unique_data_seen --- stream/utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/stream/utils.py b/stream/utils.py index 2372c74..6bd504c 100644 --- a/stream/utils.py +++ b/stream/utils.py @@ -156,9 +156,11 @@ def convert_to_full_shape(self, tensor_shape: tuple[int, ...]): return tensor_shape + (self.full_shape[-1],) def get_nb_empty_elements(self, slices: tuple[slice, ...]): + """Returns the number of points for which there are no ComputationNodes.""" assert self.is_valid_shape_dimension(slices), "Last dimension of tensor is reserved for CNs" - all_empty = np.all(self == 0, axis=-1) - return np.sum(all_empty) + all_empty_mask = np.all(self.as_ndarray() == 0, axis=-1) + all_empty_this_slice = all_empty_mask[slices] + return int(np.sum(all_empty_this_slice)) def extend_with_node(self, slices: tuple[slice, ...], node: object): assert self.is_valid_shape_dimension(slices), "Last dimension of tensor is reserved for CNs" @@ -203,12 +205,10 @@ class DiGraphWrapper(Generic[T], DiGraph): """Wraps the DiGraph class with type annotations for the nodes""" @overload - def in_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: - ... # type: ignore + def in_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: ... # type: ignore @overload - def in_edges(self, node: T) -> list[tuple[T, T]]: - ... # type: ignore + def in_edges(self, node: T) -> list[tuple[T, T]]: ... # type: ignore def in_edges( # type: ignore self, @@ -218,16 +218,13 @@ def in_edges( # type: ignore return super().in_edges(node, data) # type: ignore @overload - def out_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]: - ... # type: ignore + def out_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]: ... # type: ignore @overload - def out_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: - ... # type: ignore + def out_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: ... # type: ignore @overload - def out_edges(self, node: T) -> list[tuple[T, T]]: - ... # type: ignore + def out_edges(self, node: T) -> list[tuple[T, T]]: ... # type: ignore def out_edges( # type: ignore self,