Skip to content

Commit

Permalink
test(integration): add tests for integration
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 5, 2023
1 parent df7b64c commit f9df9f5
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 5 deletions.
8 changes: 7 additions & 1 deletion optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,13 @@ def _unravel_pytree(
return tree_unflatten(treespec, unravel_func(flat))


def _unravel_empty(_: Array) -> list[ArrayLike]:
def _unravel_empty(flat: Array) -> list[ArrayLike]:
if jnp.shape(flat) != (0,):
raise ValueError(

Check warning on line 164 in optree/integration/jax.py

View check run for this annotation

Codecov / codecov/patch

optree/integration/jax.py#L164

Added line #L164 was not covered by tests
f'The unravel function expected an array of shape {(0,)}, '
f'got shape {jnp.shape(flat)}.',
)

return []


Expand Down
7 changes: 6 additions & 1 deletion optree/integration/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ def _unravel_pytree(
return tree_unflatten(treespec, unravel_func(flat))


def _unravel_empty(_: np.ndarray) -> list[np.ndarray]:
def _unravel_empty(flat: np.ndarray) -> list[np.ndarray]:
if np.shape(flat) != (0,):
raise ValueError(

Check warning on line 93 in optree/integration/numpy.py

View check run for this annotation

Codecov / codecov/patch

optree/integration/numpy.py#L93

Added line #L93 was not covered by tests
f'The unravel function expected an array of shape {(0,)}, '
f'got shape {np.shape(flat)}.',
)
return []


Expand Down
14 changes: 11 additions & 3 deletions optree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ def _unravel_pytree(
return tree_unflatten(treespec, unravel_func(flat))


def _unravel_empty(_: torch.Tensor) -> list[torch.Tensor]:
def _unravel_empty(flat: torch.Tensor) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)}.')
if flat.shape != (0,):
raise ValueError(

Check warning on line 92 in optree/integration/torch.py

View check run for this annotation

Codecov / codecov/patch

optree/integration/torch.py#L92

Added line #L92 was not covered by tests
f'The unravel function expected a tensor of shape {(0,)}, got shape {flat.shape}.',
)
return []


Expand Down Expand Up @@ -126,9 +132,11 @@ def _unravel_leaves_single_dtype(
shapes: tuple[tuple[int, ...]],
flat: torch.Tensor,
) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)}.')
if flat.shape != (sum(sizes),):
raise ValueError(
f'The unravel function expected an array of shape {(sum(sizes),)}, '
f'The unravel function expected a tensor of shape {(sum(sizes),)}, '
f'got shape {flat.shape}.',
)

Expand All @@ -147,7 +155,7 @@ def _unravel_leaves(
raise ValueError(f'Expected a tensor to unravel, got {type(flat)}.')
if flat.shape != (sum(sizes),):
raise ValueError(
f'The unravel function expected an array of shape {(sum(sizes),)}, '
f'The unravel function expected a tensor of shape {(sum(sizes),)}, '
f'got shape {flat.shape}.',
)
if flat.dtype != to_dtype:
Expand Down
169 changes: 169 additions & 0 deletions tests/integration/test_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=missing-function-docstring,wrong-import-position,wrong-import-order

import pytest


torch = pytest.importorskip('torch')

import random

import torch

import optree
from helpers import LEAVES, TREES, parametrize


@parametrize(tree=list(TREES + LEAVES))
def test_tree_ravel(tree):
random.seed(0)

def replace_leaf(_):
candidates = [
torch.tensor(random.randint(-100, 100)),
torch.tensor(random.uniform(-100.0, 100.0)),
]

shapes = [
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
]
dtypes = [
torch.float32,
torch.float64,
torch.int32,
torch.int64,
]
for dtype in dtypes:
candidates.extend(
(5.0 * (2.0 * torch.randn(size=shape) - 1.0)).to(dtype) for shape in shapes
)

return random.choice(candidates)

tree = optree.tree_map(replace_leaf, tree)
flat, unravel_func = optree.integration.torch.tree_ravel(tree)

leaves, treespec = optree.tree_flatten(tree)
assert flat.numel() == sum(leaf.numel() for leaf in leaves)
assert flat.shape == (flat.numel(),)

reconstructed = unravel_func(flat)
reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed)
assert reconstructed_treespec == treespec
assert len(leaves) == len(reconstructed_leaves)
for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves):
assert torch.is_tensor(leaf)
assert torch.is_tensor(reconstructed_leaf)
assert torch.allclose(leaf, reconstructed_leaf)
assert leaf.dtype == reconstructed_leaf.dtype
assert leaf.shape == reconstructed_leaf.shape

with pytest.raises(ValueError, match=r'Expected a tensor to unravel, got .*\.'):
unravel_func(1)

if len(leaves) > 0:
with pytest.raises(
ValueError,
match=r'The unravel function expected a tensor of shape .*, got .*\.',
):
unravel_func(flat.reshape((-1, 1)))
with pytest.raises(
ValueError,
match=r'The unravel function expected a tensor of shape .*, got .*\.',
):
unravel_func(torch.cat([flat, torch.zeros((1,))]))

if all(leaf.dtype == flat.dtype for leaf in leaves):
unravel_func(flat.to(torch.complex128))
else:
with pytest.raises(
ValueError,
match=r'The unravel function expected a tensor of dtype .*, got dtype .*\.',
):
unravel_func(flat.to(torch.complex128))


@parametrize(tree=list(TREES + LEAVES))
def test_tree_ravel_single_dtype(tree):
random.seed(0)
dtype = torch.float16
default_dtype = torch.tensor([]).dtype

def replace_leaf(_):
candidates = []
shapes = [
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
]
candidates.extend(
(5.0 * (2.0 * torch.randn(size=shape) - 1.0)).to(dtype) for shape in shapes
)
return random.choice(candidates)

tree = optree.tree_map(replace_leaf, tree)
flat, unravel_func = optree.integration.torch.tree_ravel(tree)

leaves, treespec = optree.tree_flatten(tree)
assert flat.dtype == dtype if leaves else default_dtype
assert flat.numel() == sum(leaf.numel() for leaf in leaves)
assert flat.shape == (flat.numel(),)

reconstructed = unravel_func(flat)
reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed)
assert reconstructed_treespec == treespec
assert len(leaves) == len(reconstructed_leaves)
for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves):
assert torch.is_tensor(leaf)
assert torch.is_tensor(reconstructed_leaf)
assert torch.allclose(leaf, reconstructed_leaf)
assert leaf.dtype == reconstructed_leaf.dtype
assert leaf.shape == reconstructed_leaf.shape

with pytest.raises(ValueError, match=r'Expected a tensor to unravel, got .*\.'):
unravel_func(1)

if len(leaves) > 0:
with pytest.raises(
ValueError,
match=r'The unravel function expected a tensor of shape .*, got .*\.',
):
unravel_func(flat.reshape((-1, 1)))
with pytest.raises(
ValueError,
match=r'The unravel function expected a tensor of shape .*, got .*\.',
):
unravel_func(torch.cat([flat, torch.zeros((1,))]))

unravel_func(flat.to(torch.complex128))


def test_tree_ravel_non_tensor():
with pytest.raises(ValueError, match=r'All leaves must be tensors\.'):
optree.integration.torch.tree_ravel(1)

with pytest.raises(ValueError, match=r'All leaves must be tensors\.'):
optree.integration.torch.tree_ravel((1, 2))

with pytest.raises(ValueError, match=r'All leaves must be tensors\.'):
optree.integration.torch.tree_ravel((torch.tensor(1), 2))

optree.integration.torch.tree_ravel((torch.tensor(1), torch.tensor(2)))

0 comments on commit f9df9f5

Please sign in to comment.