From 024ed95d9fa98930915bd584d357f769a989d1d4 Mon Sep 17 00:00:00 2001 From: Alessandro Vullo Date: Tue, 13 Aug 2024 13:26:04 +0100 Subject: [PATCH 1/4] Handle the case where either query points or observations have unspecified leading dimension. --- tests/unit/test_data.py | 43 +++++++++++++++++++++++++++++++++++++++++ trieste/data.py | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 78c1c6df42..be10470990 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -13,8 +13,10 @@ # limitations under the License. from __future__ import annotations +import contextlib import copy +import numpy as np import numpy.testing as npt import pytest import tensorflow as tf @@ -68,6 +70,47 @@ def test_dataset_raises_for_different_leading_shapes( Dataset(query_points, observations) +def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None: + @contextlib.contextmanager + def does_not_raise(): + try: + yield + except Exception as e: + pytest.fail(f"An exception was raised: {e}") + + query_points = tf.zeros((2, 2)) + observations = tf.zeros((2, 1)) + + query_points_var = tf.Variable( + initial_value=np.zeros((0, 2)), + shape=(None, 2), + dtype=tf.float64, + ) + observations_var = tf.Variable( + initial_value=np.zeros((0, 1)), + shape=(None, 1), + dtype=tf.float64, + ) + + with does_not_raise(): + Dataset( + query_points=query_points_var, + observations=observations + ) + + with does_not_raise(): + Dataset( + query_points=query_points, + observations=observations_var + ) + + with does_not_raise(): + Dataset( + query_points=query_points_var, + observations=observations_var + ) + + @pytest.mark.parametrize( "query_points_shape, observations_shape", [ diff --git a/trieste/data.py b/trieste/data.py index 6c979a30e0..3a505997a2 100644 --- a/trieste/data.py +++ b/trieste/data.py @@ -52,7 +52,7 @@ def __post_init__(self) -> None: if ( self.query_points.shape[:-1] != self.observations.shape[:-1] # can't check dynamic shapes, so trust that they're ok (if not, they'll fail later) - and None not in self.query_points.shape[:-1] + and None not in self.query_points.shape[:-1] and None not in self.observations.shape[:-1] ): raise ValueError( f"Leading shapes of query_points and observations must match. Got shapes" From 4b90bb18e435017b5dee494b3ee3ddee3661abd8 Mon Sep 17 00:00:00 2001 From: Alessandro Vullo Date: Tue, 13 Aug 2024 13:31:18 +0100 Subject: [PATCH 2/4] Typedef. --- tests/unit/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index be10470990..e183a40940 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -72,7 +72,7 @@ def test_dataset_raises_for_different_leading_shapes( def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None: @contextlib.contextmanager - def does_not_raise(): + def does_not_raise() -> None: try: yield except Exception as e: From 5e0fee45df7a27233fb8bf155fb11f812644972a Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Tue, 13 Aug 2024 14:01:01 +0100 Subject: [PATCH 3/4] Format and typing --- tests/unit/test_data.py | 27 +++------------------------ trieste/data.py | 3 ++- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index e183a40940..2e0dc5d32c 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -71,13 +71,6 @@ def test_dataset_raises_for_different_leading_shapes( def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None: - @contextlib.contextmanager - def does_not_raise() -> None: - try: - yield - except Exception as e: - pytest.fail(f"An exception was raised: {e}") - query_points = tf.zeros((2, 2)) observations = tf.zeros((2, 1)) @@ -92,23 +85,9 @@ def does_not_raise() -> None: dtype=tf.float64, ) - with does_not_raise(): - Dataset( - query_points=query_points_var, - observations=observations - ) - - with does_not_raise(): - Dataset( - query_points=query_points, - observations=observations_var - ) - - with does_not_raise(): - Dataset( - query_points=query_points_var, - observations=observations_var - ) + Dataset(query_points=query_points_var, observations=observations) + Dataset(query_points=query_points, observations=observations_var) + Dataset(query_points=query_points_var, observations=observations_var) @pytest.mark.parametrize( diff --git a/trieste/data.py b/trieste/data.py index 3a505997a2..5897efdff7 100644 --- a/trieste/data.py +++ b/trieste/data.py @@ -52,7 +52,8 @@ def __post_init__(self) -> None: if ( self.query_points.shape[:-1] != self.observations.shape[:-1] # can't check dynamic shapes, so trust that they're ok (if not, they'll fail later) - and None not in self.query_points.shape[:-1] and None not in self.observations.shape[:-1] + and None not in self.query_points.shape[:-1] + and None not in self.observations.shape[:-1] ): raise ValueError( f"Leading shapes of query_points and observations must match. Got shapes" From 5b95f2bed37ae9c51b1f9af3c50700cbfce8f2ce Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Tue, 13 Aug 2024 14:02:27 +0100 Subject: [PATCH 4/4] Stray import --- tests/unit/test_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 2e0dc5d32c..349d3d442c 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import contextlib import copy import numpy as np