Skip to content

Commit

Permalink
Fix a bug in fallback function when processing 2D mono
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Jul 9, 2024
1 parent ba9115c commit 05cba03
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion numpy_rms/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ def rms_numpy(a: NDArray, window_size: int) -> NDArray:
output_i = 0
for offset in range(0, end_index, window_size):
rms = calculate_rms(a[..., offset : offset + window_size])
output_array[output_i] = rms
output_array[..., output_i] = rms
output_i += 1
return output_array
10 changes: 10 additions & 0 deletions tests/test_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ def test_rms_large_array(benchmark):
assert rms[1] == pytest.approx(7637.1353)


def test_rms_2d_mono():
arr = np.ones(shape=(1, 1000), dtype=np.float32)
rms = numpy_rms.rms(arr, window_size=50)
rms_numpy_fallback = rms_numpy(arr, window_size=50)
assert rms.shape == (1, 20)
assert rms_numpy_fallback.shape == (1, 20)
assert rms[0, 0] == 1.0
assert_array_almost_equal(rms, rms_numpy_fallback)


def test_rms_numpy_fallback_large_array(benchmark):
arr = np.arange(100_000_000, dtype=np.float32)
rms = benchmark(rms_numpy, arr, window_size=5000)
Expand Down

0 comments on commit 05cba03

Please sign in to comment.