diff --git a/numpy_rms/fallback.py b/numpy_rms/fallback.py index c634128..c5e4177 100644 --- a/numpy_rms/fallback.py +++ b/numpy_rms/fallback.py @@ -19,6 +19,6 @@ def rms_numpy(a: NDArray, window_size: int) -> NDArray: output_i = 0 for offset in range(0, end_index, window_size): rms = calculate_rms(a[..., offset : offset + window_size]) - output_array[output_i] = rms + output_array[..., output_i] = rms output_i += 1 return output_array diff --git a/tests/test_rms.py b/tests/test_rms.py index ee04c9e..67e75fd 100644 --- a/tests/test_rms.py +++ b/tests/test_rms.py @@ -14,6 +14,16 @@ def test_rms_large_array(benchmark): assert rms[1] == pytest.approx(7637.1353) +def test_rms_2d_mono(): + arr = np.ones(shape=(1, 1000), dtype=np.float32) + rms = numpy_rms.rms(arr, window_size=50) + rms_numpy_fallback = rms_numpy(arr, window_size=50) + assert rms.shape == (1, 20) + assert rms_numpy_fallback.shape == (1, 20) + assert rms[0, 0] == 1.0 + assert_array_almost_equal(rms, rms_numpy_fallback) + + def test_rms_numpy_fallback_large_array(benchmark): arr = np.arange(100_000_000, dtype=np.float32) rms = benchmark(rms_numpy, arr, window_size=5000)