Skip to content

Commit

Permalink
fix bug in nb_unique_data_seen
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinGeens committed Aug 16, 2024
1 parent 5fddb76 commit f3e5ee4
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions stream/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,11 @@ def convert_to_full_shape(self, tensor_shape: tuple[int, ...]):
return tensor_shape + (self.full_shape[-1],)

def get_nb_empty_elements(self, slices: tuple[slice, ...]):
"""Returns the number of points for which there are no ComputationNodes."""
assert self.is_valid_shape_dimension(slices), "Last dimension of tensor is reserved for CNs"
all_empty = np.all(self == 0, axis=-1)
return np.sum(all_empty)
all_empty_mask = np.all(self.as_ndarray() == 0, axis=-1)
all_empty_this_slice = all_empty_mask[slices]
return int(np.sum(all_empty_this_slice))

def extend_with_node(self, slices: tuple[slice, ...], node: object):
assert self.is_valid_shape_dimension(slices), "Last dimension of tensor is reserved for CNs"
Expand Down Expand Up @@ -203,12 +205,10 @@ class DiGraphWrapper(Generic[T], DiGraph):
"""Wraps the DiGraph class with type annotations for the nodes"""

@overload
def in_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]:
... # type: ignore
def in_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: ... # type: ignore

@overload
def in_edges(self, node: T) -> list[tuple[T, T]]:
... # type: ignore
def in_edges(self, node: T) -> list[tuple[T, T]]: ... # type: ignore

def in_edges( # type: ignore
self,
Expand All @@ -218,16 +218,13 @@ def in_edges( # type: ignore
return super().in_edges(node, data) # type: ignore

@overload
def out_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]:
... # type: ignore
def out_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]: ... # type: ignore

@overload
def out_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]:
... # type: ignore
def out_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: ... # type: ignore

@overload
def out_edges(self, node: T) -> list[tuple[T, T]]:
... # type: ignore
def out_edges(self, node: T) -> list[tuple[T, T]]: ... # type: ignore

def out_edges( # type: ignore
self,
Expand Down

0 comments on commit f3e5ee4

Please sign in to comment.