From 5d10c39c6d7c127ae9affcaccd17c0cff31dbc62 Mon Sep 17 00:00:00 2001 From: manuba95 Date: Thu, 27 Jul 2023 14:56:50 +0200 Subject: [PATCH 1/2] fix: nanargmin with all nan slice --- floodlight/models/space.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/floodlight/models/space.py b/floodlight/models/space.py index 7843abca..9f285180 100644 --- a/floodlight/models/space.py +++ b/floodlight/models/space.py @@ -200,7 +200,11 @@ def _calc_cell_controls(self, xy1: XY, xy2: XY): # calculate pairwise distances and determine closest player pairwise_distances = cdist(mesh_points, player_points) - closest_player_index = np.nanargmin(pairwise_distances, axis=1) + closest_player_index = np.where( + np.isnan(pairwise_distances).all(axis=1), + np.NaN, + np.nanargmin(pairwise_distances, axis=1) + ) self._cell_controls_[t] = closest_player_index.reshape(self._meshx_.shape) def fit(self, xy1: XY, xy2: XY): From 277fa2499a5517b327e5e08e362edc6a3e94c63d Mon Sep 17 00:00:00 2001 From: manuba95 Date: Tue, 21 May 2024 16:16:10 +0200 Subject: [PATCH 2/2] replace np.where() --- floodlight/models/space.py | 49 ++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/floodlight/models/space.py b/floodlight/models/space.py index 9f285180..4977c0d6 100644 --- a/floodlight/models/space.py +++ b/floodlight/models/space.py @@ -192,19 +192,26 @@ def _calc_cell_controls(self, xy1: XY, xy2: XY): np.nan, ) + # stack and reshape mesh coordinates to (M x 2) array + mesh_points = np.stack((self._meshx_, self._meshy_), axis=2).reshape(-1, 2) + # loop for t in range(T): - # stack and reshape player and mesh coordinates to (M x 2) arrays + # stack and reshape player coordinates to (M x 2) array player_points = np.hstack((xy1.frame(t), xy2.frame(t))).reshape(-1, 2) - mesh_points = np.stack((self._meshx_, self._meshy_), axis=2).reshape(-1, 2) # calculate pairwise distances and determine closest player pairwise_distances = cdist(mesh_points, player_points) - closest_player_index = np.where( - np.isnan(pairwise_distances).all(axis=1), - np.NaN, - np.nanargmin(pairwise_distances, axis=1) - ) + + # identify valid segments without all-NaN slices + all_nan_mask = np.isnan(pairwise_distances).all(axis=1) + valid_mask = ~all_nan_mask + + # Init closest player index array + closest_player_index = np.full(pairwise_distances.shape[0], np.NaN) + + if np.any(valid_mask): + closest_player_index = np.nanargmin(pairwise_distances, axis=1) self._cell_controls_[t] = closest_player_index.reshape(self._meshx_.shape) def fit(self, xy1: XY, xy2: XY): @@ -356,20 +363,26 @@ def plot( .. image:: ../../_img/sample_dvm_plot_hex.png """ - # get ax - ax = ax or plt.subplots()[1] - # get colors and construct team color vector - team_color1, team_color2 = team_colors - color_vector = [team_color1] * self._N1_ + [team_color2] * self._N2_ + # check if t refers to an all-nan slice in the cell controlls + if np.isnan(self._cell_controls_[t]).all(): + pass + else: - # call plot by mesh type - if self._mesh_type == "square": - ax = self._plot_square(t, color_vector, ax=ax, **kwargs) - elif self._mesh_type == "hexagonal": - ax = self._plot_hexagonal(t, color_vector, ax=ax, **kwargs) + # get ax + ax = ax or plt.subplots()[1] - return ax + # get colors and construct team color vector + team_color1, team_color2 = team_colors + color_vector = [team_color1] * self._N1_ + [team_color2] * self._N2_ + + # call plot by mesh type + if self._mesh_type == "square": + ax = self._plot_square(t, color_vector, ax=ax, **kwargs) + elif self._mesh_type == "hexagonal": + ax = self._plot_hexagonal(t, color_vector, ax=ax, **kwargs) + + return ax def _plot_square( self,