From 10710283d5a6d2c70b4d8b6cb0f146d7ed349033 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 1 Oct 2024 20:47:07 -0700 Subject: [PATCH] Refactor histogram op --- keras/api/_tf_keras/keras/layers/__init__.py | 3 + keras/api/_tf_keras/keras/ops/__init__.py | 2 +- .../api/_tf_keras/keras/ops/numpy/__init__.py | 1 + keras/api/layers/__init__.py | 3 + keras/api/ops/__init__.py | 2 +- keras/api/ops/numpy/__init__.py | 1 + keras/src/backend/jax/math.py | 4 - keras/src/backend/jax/numpy.py | 4 + keras/src/backend/numpy/math.py | 4 - keras/src/backend/numpy/numpy.py | 4 + keras/src/backend/tensorflow/math.py | 28 ----- keras/src/backend/tensorflow/numpy.py | 28 +++++ keras/src/backend/torch/math.py | 5 - keras/src/backend/torch/numpy.py | 5 + keras/src/ops/math.py | 86 -------------- keras/src/ops/math_test.py | 112 ------------------ keras/src/ops/numpy.py | 96 +++++++++++++++ keras/src/ops/numpy_test.py | 112 ++++++++++++++++++ 18 files changed, 259 insertions(+), 241 deletions(-) diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 7c905b9efad..526fbc65614 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -160,6 +160,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear, +) from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( RandomTranslation, ) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index a1730855e11..20cf46889d2 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -47,7 +47,6 @@ from keras.src.ops.math import extract_sequences from keras.src.ops.math import fft from keras.src.ops.math import fft2 -from keras.src.ops.math import histogram from keras.src.ops.math import in_top_k from keras.src.ops.math import irfft from keras.src.ops.math import istft @@ -160,6 +159,7 @@ from keras.src.ops.numpy import get_item from keras.src.ops.numpy import greater from keras.src.ops.numpy import greater_equal +from keras.src.ops.numpy import histogram from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index 9be9476190b..311180adb41 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -70,6 +70,7 @@ from keras.src.ops.numpy import get_item from keras.src.ops.numpy import greater from keras.src.ops.numpy import greater_equal +from keras.src.ops.numpy import histogram from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 2c1b3d57643..8e654936303 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -160,6 +160,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear, +) from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( RandomTranslation, ) diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index a1730855e11..20cf46889d2 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -47,7 +47,6 @@ from keras.src.ops.math import extract_sequences from keras.src.ops.math import fft from keras.src.ops.math import fft2 -from keras.src.ops.math import histogram from keras.src.ops.math import in_top_k from keras.src.ops.math import irfft from keras.src.ops.math import istft @@ -160,6 +159,7 @@ from keras.src.ops.numpy import get_item from keras.src.ops.numpy import greater from keras.src.ops.numpy import greater_equal +from keras.src.ops.numpy import histogram from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index 9be9476190b..311180adb41 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -70,6 +70,7 @@ from keras.src.ops.numpy import get_item from keras.src.ops.numpy import greater from keras.src.ops.numpy import greater_equal +from keras.src.ops.numpy import histogram from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 11c96086ced..18ba91862a9 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -294,7 +294,3 @@ def logdet(x): # `np.log(np.linalg.det(x))`. See # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html return slogdet(x)[1] - - -def histogram(x, bins, range): - return jnp.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 4f570502cba..7251333d7d6 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -1246,3 +1246,7 @@ def slogdet(x): def argpartition(x, kth, axis=-1): return jnp.argpartition(x, kth, axis) + + +def histogram(x, bins, range): + return jnp.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/numpy/math.py b/keras/src/backend/numpy/math.py index a40cd569578..f9448c92b93 100644 --- a/keras/src/backend/numpy/math.py +++ b/keras/src/backend/numpy/math.py @@ -316,7 +316,3 @@ def logdet(x): # In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html return slogdet(x)[1] - - -def histogram(x, bins, range): - return np.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index 3ee17fd9e6f..f6817958772 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -1177,3 +1177,7 @@ def slogdet(x): def argpartition(x, kth, axis=-1): return np.argpartition(x, kth, axis).astype("int32") + + +def histogram(x, bins, range): + return np.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/tensorflow/math.py b/keras/src/backend/tensorflow/math.py index 3e50fbc9773..f034cf429e1 100644 --- a/keras/src/backend/tensorflow/math.py +++ b/keras/src/backend/tensorflow/math.py @@ -370,31 +370,3 @@ def norm(x, ord=None, axis=None, keepdims=False): def logdet(x): x = convert_to_tensor(x) return tf.linalg.logdet(x) - - -def histogram(x, bins, range): - """ - Computes a histogram of the data tensor `x` using TensorFlow. - The `tf.histogram_fixed_width()` and `tf.histogram_fixed_width_bins()` - methods yielded slight numerical differences on some edge cases. - """ - - x = tf.convert_to_tensor(x, dtype=x.dtype) - - # Handle the range argument - if range is None: - min_val = tf.reduce_min(x) - max_val = tf.reduce_max(x) - else: - min_val, max_val = range - - x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val)) - bin_edges = tf.linspace(min_val, max_val, bins + 1) - bin_edges_list = bin_edges.numpy().tolist() - bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1]) - - bin_counts = tf.math.bincount( - bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype - ) - - return bin_counts, bin_edges diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 4dca87dbacc..a137f414acf 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2546,3 +2546,31 @@ def argpartition(x, kth, axis=-1): out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1) return swapaxes(out, -1, axis) + + +def histogram(x, bins, range): + """Computes a histogram of the data tensor `x`. + + Note: the `tf.histogram_fixed_width()` and + `tf.histogram_fixed_width_bins()` functions + yield slight numerical differences for some edge cases. + """ + + x = tf.convert_to_tensor(x, dtype=x.dtype) + + # Handle the range argument + if range is None: + min_val = tf.reduce_min(x) + max_val = tf.reduce_max(x) + else: + min_val, max_val = range + + x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val)) + bin_edges = tf.linspace(min_val, max_val, bins + 1) + bin_edges_list = bin_edges.numpy().tolist() + bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1]) + + bin_counts = tf.math.bincount( + bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype + ) + return bin_counts, bin_edges diff --git a/keras/src/backend/torch/math.py b/keras/src/backend/torch/math.py index e05d358e901..e2e80e9358c 100644 --- a/keras/src/backend/torch/math.py +++ b/keras/src/backend/torch/math.py @@ -419,8 +419,3 @@ def norm(x, ord=None, axis=None, keepdims=False): def logdet(x): x = convert_to_tensor(x) return torch.logdet(x) - - -def histogram(x, bins, range): - hist_result = torch.histogram(x, bins=bins, range=range) - return hist_result.hist, hist_result.bin_edges diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index c2fffd460c1..a8726f0b2a9 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -1701,3 +1701,8 @@ def set_to_zero(a, i): top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1] out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1) return cast(torch.transpose(out, -1, axis), "int32") + + +def histogram(x, bins, range): + hist_result = torch.histogram(x, bins=bins, range=range) + return hist_result.hist, hist_result.bin_edges diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py index 749142b6a6d..fd0a41d5177 100644 --- a/keras/src/ops/math.py +++ b/keras/src/ops/math.py @@ -971,89 +971,3 @@ def logdet(x): if any_symbolic_tensors((x,)): return Logdet().symbolic_call(x) return backend.math.logdet(x) - - -class Histogram(Operation): - def __init__(self, bins=10, range=None): - super().__init__() - - if not isinstance(bins, int): - raise TypeError("bins must be of type `int`") - if bins < 0: - raise ValueError("`bins` should be a non-negative integer") - - if range: - if len(range) < 2 or not isinstance(range, tuple): - raise ValueError("range must be a tuple of two elements") - - if range[1] < range[0]: - raise ValueError( - "The second element of range must be greater than the first" - ) - - self.bins = bins - self.range = range - - def call(self, x): - x = backend.convert_to_tensor(x) - if len(x.shape) > 1: - raise ValueError("Input tensor must be 1-dimensional") - return backend.math.histogram(x, bins=self.bins, range=self.range) - - def compute_output_spec(self, x): - return ( - KerasTensor(shape=(self.bins,), dtype=x.dtype), - KerasTensor(shape=(self.bins + 1,), dtype=x.dtype), - ) - - -@keras_export("keras.ops.histogram") -def histogram(x, bins=10, range=None): - """Computes a histogram of the data tensor `x`. - - Args: - x: Input tensor. - bins: An integer representing the number of histogram bins. - Defaults to 10. - range: A tuple representing the lower and upper range of the bins. - If not specified, it will use the min and max of `x`. - - Returns: - A tuple containing: - - A tensor representing the counts of elements in each bin. - - A tensor representing the bin edges. - - Example: - - ``` - >>> nput_tensor = np.random.rand(8) - >>> keras.ops.histogram(input_tensor) - (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32), - array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262, - 0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101, - 0.85892869])) - ``` - - """ - - if not isinstance(bins, int): - raise TypeError("bins must be of type `int`") - if bins < 0: - raise ValueError("`bins` should be a non-negative integer") - - if range: - if len(range) < 2 or not isinstance(range, tuple): - raise ValueError("range must be a tuple of two elements") - - if range[1] < range[0]: - raise ValueError( - "The second element of range must be greater than the first" - ) - - if any_symbolic_tensors((x,)): - return Histogram(bins=bins, range=range).symbolic_call(x) - - x = backend.convert_to_tensor(x) - if len(x.shape) > 1: - raise ValueError("Input tensor must be 1-dimensional") - return backend.math.histogram(x, bins=bins, range=range) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 09bcb9503fd..09c87514c78 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -1468,115 +1468,3 @@ def test_istft_invalid_window_shape_2D_inputs(self): fft_length, window=incorrect_window, ) - - -class HistogramTest(testing.TestCase): - def test_histogram_default_args(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(8) - - # Expected output - expected_counts, expected_edges = np.histogram(input_tensor) - - counts, edges = hist_op(input_tensor) - - self.assertEqual(counts.shape, expected_counts.shape) - self.assertAllClose(counts, expected_counts) - self.assertEqual(edges.shape, expected_edges.shape) - self.assertAllClose(edges, expected_edges) - - def test_histogram_custom_bins(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(8) - bins = 5 - - # Expected output - expected_counts, expected_edges = np.histogram(input_tensor, bins=bins) - - counts, edges = hist_op(input_tensor, bins=bins) - - self.assertEqual(counts.shape, expected_counts.shape) - self.assertAllClose(counts, expected_counts) - self.assertEqual(edges.shape, expected_edges.shape) - self.assertAllClose(edges, expected_edges) - - def test_histogram_custom_range(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(10) - range_specified = (2, 8) - - # Expected output - expected_counts, expected_edges = np.histogram( - input_tensor, range=range_specified - ) - - counts, edges = hist_op(input_tensor, range=range_specified) - - self.assertEqual(counts.shape, expected_counts.shape) - self.assertAllClose(counts, expected_counts) - self.assertEqual(edges.shape, expected_edges.shape) - self.assertAllClose(edges, expected_edges) - - def test_histogram_symbolic_input(self): - hist_op = kmath.histogram - input_tensor = KerasTensor(shape=(None,), dtype="float32") - - counts, edges = hist_op(input_tensor) - - self.assertEqual(counts.shape, (10,)) - self.assertEqual(edges.shape, (11,)) - - def test_histogram_non_integer_bins_raises_error(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(8) - - with self.assertRaisesRegex( - ValueError, "`bins` should be a non-negative integer" - ): - hist_op(input_tensor, bins=-5) - - def test_histogram_range_validation(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(8) - - with self.assertRaisesRegex( - ValueError, "range must be a tuple of two elements" - ): - hist_op(input_tensor, range=(1,)) - - with self.assertRaisesRegex( - ValueError, - "The second element of range must be greater than the first", - ): - hist_op(input_tensor, range=(5, 1)) - - def test_histogram_large_values(self): - hist_op = kmath.histogram - input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10]) - - counts, edges = hist_op(input_tensor, bins=5) - - expected_counts, expected_edges = np.histogram(input_tensor, bins=5) - - self.assertAllClose(counts, expected_counts) - self.assertAllClose(edges, expected_edges) - - def test_histogram_float_input(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(8) - - counts, edges = hist_op(input_tensor, bins=5) - - expected_counts, expected_edges = np.histogram(input_tensor, bins=5) - - self.assertAllClose(counts, expected_counts) - self.assertAllClose(edges, expected_edges) - - def test_histogram_high_dimensional_input(self): - hist_op = kmath.histogram - input_tensor = np.random.rand(3, 4, 5) - - with self.assertRaisesRegex( - ValueError, "Input tensor must be 1-dimensional" - ): - hist_op(input_tensor) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 9e1de5531ee..9a82cc982a7 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -6611,3 +6611,99 @@ def argpartition(x, kth, axis=-1): if any_symbolic_tensors((x,)): return Argpartition(kth, axis).symbolic_call(x) return backend.numpy.argpartition(x, kth, axis) + + +class Histogram(Operation): + def __init__(self, bins=10, range=None): + super().__init__() + + if not isinstance(bins, int): + raise TypeError("bins must be of type `int`") + if bins < 0: + raise ValueError("`bins` should be a non-negative integer") + + if range: + if len(range) < 2 or not isinstance(range, tuple): + raise ValueError("range must be a tuple of two elements") + + if range[1] < range[0]: + raise ValueError( + "The second element of range must be greater than the first" + ) + + self.bins = bins + self.range = range + + def call(self, x): + x = backend.convert_to_tensor(x) + if len(x.shape) > 1: + raise ValueError("Input tensor must be 1-dimensional") + return backend.math.histogram(x, bins=self.bins, range=self.range) + + def compute_output_spec(self, x): + return ( + KerasTensor(shape=(self.bins,), dtype=x.dtype), + KerasTensor(shape=(self.bins + 1,), dtype=x.dtype), + ) + + +@keras_export(["keras.ops.histogram", "keras.ops.numpy.histogram"]) +def histogram(x, bins=10, range=None): + """Computes a histogram of the data tensor `x`. + + Args: + x: Input tensor. + bins: An integer representing the number of histogram bins. + Defaults to 10. + range: A tuple representing the lower and upper range of the bins. + If not specified, it will use the min and max of `x`. + + Returns: + A tuple containing: + - A tensor representing the counts of elements in each bin. + - A tensor representing the bin edges. + + Example: + + ``` + >>> input_tensor = np.random.rand(8) + >>> keras.ops.histogram(input_tensor) + (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32), + array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262, + 0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101, + 0.85892869])) + ``` + """ + if not isinstance(bins, int): + raise TypeError( + f"Argument `bins` must be of type `int`. Received: bins={bins}" + ) + if bins < 0: + raise ValueError( + "Argument `bins` should be a non-negative integer. " + f"Received: bins={bins}" + ) + + if range: + if len(range) < 2 or not isinstance(range, tuple): + raise ValueError( + "Argument `range` must be a tuple of two elements. " + f"Received: range={range}" + ) + + if range[1] < range[0]: + raise ValueError( + "The second element of `range` must be greater than the first. " + f"Received: range={range}" + ) + + if any_symbolic_tensors((x,)): + return Histogram(bins=bins, range=range).symbolic_call(x) + + x = backend.convert_to_tensor(x) + if len(x.shape) > 1: + raise ValueError( + "Input tensor must be 1-dimensional. " + f"Received: input.shape={x.shape}" + ) + return backend.numpy.histogram(x, bins=bins, range=range) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index a4143328bcf..0e94ee515e3 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -8332,3 +8332,115 @@ def test_zeros_like(self, dtype): standardize_dtype(knp.ZerosLike().symbolic_call(x).dtype), expected_dtype, ) + + +class HistogramTest(testing.TestCase): + def test_histogram_default_args(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + # Expected output + expected_counts, expected_edges = np.histogram(input_tensor) + + counts, edges = hist_op(input_tensor) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_custom_bins(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + bins = 5 + + # Expected output + expected_counts, expected_edges = np.histogram(input_tensor, bins=bins) + + counts, edges = hist_op(input_tensor, bins=bins) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_custom_range(self): + hist_op = knp.histogram + input_tensor = np.random.rand(10) + range_specified = (2, 8) + + # Expected output + expected_counts, expected_edges = np.histogram( + input_tensor, range=range_specified + ) + + counts, edges = hist_op(input_tensor, range=range_specified) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_symbolic_input(self): + hist_op = knp.histogram + input_tensor = KerasTensor(shape=(None,), dtype="float32") + + counts, edges = hist_op(input_tensor) + + self.assertEqual(counts.shape, (10,)) + self.assertEqual(edges.shape, (11,)) + + def test_histogram_non_integer_bins_raises_error(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + with self.assertRaisesRegex( + ValueError, "Argument `bins` should be a non-negative integer" + ): + hist_op(input_tensor, bins=-5) + + def test_histogram_range_validation(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + with self.assertRaisesRegex( + ValueError, "Argument `range` must be a tuple of two elements" + ): + hist_op(input_tensor, range=(1,)) + + with self.assertRaisesRegex( + ValueError, + "The second element of `range` must be greater than the first", + ): + hist_op(input_tensor, range=(5, 1)) + + def test_histogram_large_values(self): + hist_op = knp.histogram + input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10]) + + counts, edges = hist_op(input_tensor, bins=5) + + expected_counts, expected_edges = np.histogram(input_tensor, bins=5) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + def test_histogram_float_input(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + counts, edges = hist_op(input_tensor, bins=5) + + expected_counts, expected_edges = np.histogram(input_tensor, bins=5) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + def test_histogram_high_dimensional_input(self): + hist_op = knp.histogram + input_tensor = np.random.rand(3, 4, 5) + + with self.assertRaisesRegex( + ValueError, "Input tensor must be 1-dimensional" + ): + hist_op(input_tensor)