diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 78c1c6df4..349d3d442 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -15,6 +15,7 @@ import copy +import numpy as np import numpy.testing as npt import pytest import tensorflow as tf @@ -68,6 +69,26 @@ def test_dataset_raises_for_different_leading_shapes( Dataset(query_points, observations) +def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None: + query_points = tf.zeros((2, 2)) + observations = tf.zeros((2, 1)) + + query_points_var = tf.Variable( + initial_value=np.zeros((0, 2)), + shape=(None, 2), + dtype=tf.float64, + ) + observations_var = tf.Variable( + initial_value=np.zeros((0, 1)), + shape=(None, 1), + dtype=tf.float64, + ) + + Dataset(query_points=query_points_var, observations=observations) + Dataset(query_points=query_points, observations=observations_var) + Dataset(query_points=query_points_var, observations=observations_var) + + @pytest.mark.parametrize( "query_points_shape, observations_shape", [ diff --git a/trieste/data.py b/trieste/data.py index 6c979a30e..5897efdff 100644 --- a/trieste/data.py +++ b/trieste/data.py @@ -53,6 +53,7 @@ def __post_init__(self) -> None: self.query_points.shape[:-1] != self.observations.shape[:-1] # can't check dynamic shapes, so trust that they're ok (if not, they'll fail later) and None not in self.query_points.shape[:-1] + and None not in self.observations.shape[:-1] ): raise ValueError( f"Leading shapes of query_points and observations must match. Got shapes"