Skip to content

Commit

Permalink
Remove np.random.RandomState as input type
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Apr 27, 2024
1 parent 7284e73 commit 3e1419c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
16 changes: 8 additions & 8 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3781,7 +3781,7 @@ def to_pointcloud(
subsample: float | int = 1,
*,
as_array: Literal[False] = False,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
force_pixel_offset: Literal["center", "ul", "ur", "ll", "lr"] = "ul",
) -> NDArrayNum:
...
Expand All @@ -3796,7 +3796,7 @@ def to_pointcloud(
subsample: float | int = 1,
*,
as_array: Literal[True],
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
force_pixel_offset: Literal["center", "ul", "ur", "ll", "lr"] = "ul",
) -> Vector:
...
Expand All @@ -3811,7 +3811,7 @@ def to_pointcloud(
subsample: float | int = 1,
*,
as_array: bool = False,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
force_pixel_offset: Literal["center", "ul", "ur", "ll", "lr"] = "ul",
) -> NDArrayNum | Vector:
...
Expand All @@ -3824,7 +3824,7 @@ def to_pointcloud(
auxiliary_column_names: list[str] | None = None,
subsample: float | int = 1,
as_array: bool = False,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
force_pixel_offset: Literal["center", "ul", "ur", "ll", "lr"] = "ul",
) -> NDArrayNum | Vector:
"""
Expand Down Expand Up @@ -4106,7 +4106,7 @@ def subsample(
subsample: int | float,
return_indices: Literal[False] = False,
*,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> NDArrayNum:
...

Expand All @@ -4116,7 +4116,7 @@ def subsample(
subsample: int | float,
return_indices: Literal[True],
*,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> tuple[NDArrayNum, ...]:
...

Expand All @@ -4125,15 +4125,15 @@ def subsample(
self,
subsample: float | int,
return_indices: bool = False,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> NDArrayNum | tuple[NDArrayNum, ...]:
...

def subsample(
self,
subsample: float | int,
return_indices: bool = False,
random_state: np.random.RandomState | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> NDArrayNum | tuple[NDArrayNum, ...]:
"""
Randomly sample the raster. Only valid values are considered.
Expand Down
15 changes: 5 additions & 10 deletions geoutils/raster/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def subsample_array(
subsample: float | int,
return_indices: Literal[False] = False,
*,
random_state: np.random.RandomState | np.random.Generator | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> NDArrayNum:
...

Expand All @@ -27,7 +27,7 @@ def subsample_array(
subsample: float | int,
return_indices: Literal[True],
*,
random_state: np.random.RandomState | np.random.Generator | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> tuple[NDArrayNum, ...]:
...

Expand All @@ -37,7 +37,7 @@ def subsample_array(
array: NDArrayNum | MArrayNum,
subsample: float | int,
return_indices: bool = False,
random_state: np.random.RandomState | np.random.Generator | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> NDArrayNum | tuple[NDArrayNum, ...]:
...

Expand All @@ -46,7 +46,7 @@ def subsample_array(
array: NDArrayNum | MArrayNum,
subsample: float | int,
return_indices: bool = False,
random_state: np.random.RandomState | np.random.Generator | int | None = None,
random_state: int | np.random.Generator | None = None,
) -> NDArrayNum | tuple[NDArrayNum, ...]:
"""
Randomly subsample a 1D or 2D array by a sampling factor, taking only non NaN/masked values.
Expand All @@ -60,12 +60,7 @@ def subsample_array(
:returns: The subsampled array (1D) or the indices to extract (same shape as input array)
"""
# Define state for random sampling (to fix results during testing)
if random_state is None:
rng: np.random.RandomState | np.random.Generator = np.random.default_rng()
elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
rng = random_state
else:
rng = np.random.default_rng(random_state)
rng = np.random.default_rng(random_state)

# Remove invalid values and flatten array
mask = get_mask_from_array(array) # -> need to remove .squeeze in get_mask
Expand Down

0 comments on commit 3e1419c

Please sign in to comment.