Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: make error message more clear when value mismatch #36

Merged
merged 2 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ max-line-length = 120
max-doc-length = 100
select = B,C,E,F,W,Y,SIM
ignore =
# E203: whitespace before ':'
# W503: line break before binary operator
# W504: line break after binary operator
# format by black
E203,W503,W504,
# E501: line too long
# W505: doc line too long
# too long docstring due to long example blocks
E501,W505,
# W503: line break before binary operator
# W504: line break after binary operator
# format by black
W503,W504,
per-file-ignores =
# F401: module imported but unused
# intentionally unused imports
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
- id: clang-format
stages: [commit, push, manual]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.249
rev: v0.0.252
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Make error message more clear when value mismatch by [@XuehaiPan](https://github.com/XuehaiPan) in [#36](https://github.com/metaopt/optree/pull/36).
- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#33](https://github.com/metaopt/optree/pull/33) and [#34](https://github.com/metaopt/optree/pull/34).

### Changed
Expand Down
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ There are several key attributes of the pytree type registry:

# Children are also `np.ndarray`s, recurse without termination condition.
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy3')
RecursionError: maximum recursion depth exceeded during flattening the tree
Traceback (most recent call last):
...
RecursionError: Maximum recursion depth exceeded during flattening the tree.

>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch1')
(
Expand All @@ -398,7 +400,9 @@ There are several key attributes of the pytree type registry:

# Children are also `torch.Tensor`s, recurse without termination condition.
>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch2')
RecursionError: maximum recursion depth exceeded during flattening the tree
Traceback (most recent call last):
...
RecursionError: Maximum recursion depth exceeded during flattening the tree.
```

### `None` is Non-leaf Node vs. `None` is Leaf
Expand Down Expand Up @@ -451,6 +455,8 @@ OrderedDict([
])

>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
Traceback (most recent call last):
...
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType

>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
Expand Down Expand Up @@ -484,6 +490,8 @@ The keys are sorted in ascending order by `key=lambda k: k` first if capable oth
>>> sorted({1: 2, 1.5: 1}.keys())
[1, 1.5]
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys())
Traceback (most recent call last):
...
TypeError: '<' not supported between instances of 'int' and 'str'
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys(), key=lambda k: (k.__class__.__qualname__, k))
[1.5, 1, 'a']
Expand Down
31 changes: 17 additions & 14 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,19 +314,22 @@ inline void AssertExact(const py::handle& object) {
template <>
inline void AssertExact<py::list>(const py::handle& object) {
if (!PyList_CheckExact(object.ptr())) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat("Expected list, got %s.", py::repr(object)));
throw std::invalid_argument(
absl::StrFormat("Expected an instance of list, got %s.", py::repr(object)));
}
}
template <>
inline void AssertExact<py::tuple>(const py::handle& object) {
if (!PyTuple_CheckExact(object.ptr())) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
throw std::invalid_argument(
absl::StrFormat("Expected an instance of tuple, got %s.", py::repr(object)));
}
}
template <>
inline void AssertExact<py::dict>(const py::handle& object) {
if (!PyDict_CheckExact(object.ptr())) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat("Expected dict, got %s.", py::repr(object)));
throw std::invalid_argument(
absl::StrFormat("Expected an instance of dict, got %s.", py::repr(object)));
}
}

Expand Down Expand Up @@ -357,29 +360,29 @@ inline bool IsNamedTupleClass(const py::handle& type) {
inline bool IsNamedTuple(const py::handle& object) { return IsNamedTupleClass(object.get_type()); }
inline void AssertExactNamedTuple(const py::handle& object) {
if (!IsNamedTuple(object)) [[unlikely]] {
throw std::invalid_argument(
absl::StrFormat("Expected collections.namedtuple, got %s.", py::repr(object)));
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of collections.namedtuple, got %s.", py::repr(object)));
}
}

inline void AssertExactOrderedDict(const py::handle& object) {
if (!object.get_type().is(PyOrderedDictTypeObject)) [[unlikely]] {
throw std::invalid_argument(
absl::StrFormat("Expected collections.OrderedDict, got %s.", py::repr(object)));
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of collections.OrderedDict, got %s.", py::repr(object)));
}
}

inline void AssertExactDefaultDict(const py::handle& object) {
if (!object.get_type().is(PyDefaultDictTypeObject)) [[unlikely]] {
throw std::invalid_argument(
absl::StrFormat("Expected collections.defaultdict, got %s.", py::repr(object)));
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of collections.defaultdict, got %s.", py::repr(object)));
}
}

inline void AssertExactDeque(const py::handle& object) {
if (!object.get_type().is(PyDequeTypeObject)) [[unlikely]] {
throw std::invalid_argument(
absl::StrFormat("Expected collections.deque, got %s.", py::repr(object)));
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of collections.deque, got %s.", py::repr(object)));
}
}

Expand Down Expand Up @@ -414,8 +417,8 @@ inline bool IsStructSequence(const py::handle& object) {
}
inline void AssertExactStructSequence(const py::handle& object) {
if (!IsStructSequence(object)) [[unlikely]] {
throw std::invalid_argument(
absl::StrFormat("Expected StructSequence, got %s.", py::repr(object)));
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of StructSequence type, got %s.", py::repr(object)));
}
}
inline py::tuple StructSequenceGetFields(const py::handle& object) {
Expand All @@ -424,7 +427,7 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) {
type = object;
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw std::invalid_argument(
absl::StrFormat("Expected StructSequence type, got %s.", py::repr(object)));
absl::StrFormat("Expected a StructSequence type, got %s.", py::repr(object)));
}
} else {
type = object.get_type();
Expand Down
2 changes: 1 addition & 1 deletion optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
]

MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal. It is 5000 on Unix systems and 2500 on Windows."""
"""Maximum recursion depth for pytree traversal. It is 5000 on Unix-like systems and 2500 on Windows."""
NONE_IS_NODE: bool = NONE_IS_NODE # literal constant
"""Literal constant that treats :data:`None` as a pytree non-leaf node."""
NONE_IS_LEAF: bool = NONE_IS_LEAF # literal constant
Expand Down
23 changes: 15 additions & 8 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
]

MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal. It is 5000 on Unix systems and 2500 on Windows."""
"""Maximum recursion depth for pytree traversal. It is 5000 on Unix-like systems and 2500 on Windows."""
NONE_IS_NODE: bool = False # literal constant
"""Literal constant that treats :data:`None` as a pytree non-leaf node."""
NONE_IS_LEAF: bool = True # literal constant
Expand Down Expand Up @@ -690,6 +690,7 @@ def tree_transpose(
outer_treespec: PyTreeSpec,
inner_treespec: PyTreeSpec,
tree: PyTree[T],
is_leaf: Callable[[T], bool] | None = None,
) -> PyTree[T]: # PyTree[PyTree[T]]
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

Expand Down Expand Up @@ -718,6 +719,10 @@ def tree_transpose(
outer_treespec (PyTreeSpec): A treespec object representing the outer structure of the pytree.
inner_treespec (PyTreeSpec): A treespec object representing the inner structure of the pytree.
tree (pytree): A pytree to be transposed.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.

Returns:
A new pytree with the same structure as ``inner_treespec`` but with the value at each leaf
Expand All @@ -738,19 +743,21 @@ def tree_transpose(
f'Tree structures must have the same namespace. '
f'Got {outer_treespec.namespace!r} vs. {inner_treespec.namespace!r}.'
)

leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=outer_treespec.none_is_leaf,
namespace=outer_treespec.namespace or inner_treespec.namespace,
)
if treespec.num_leaves != inner_size * outer_size:
if treespec.num_leaves != outer_size * inner_size:
expected_treespec = outer_treespec.compose(inner_treespec)
raise TypeError(f'Tree structure mismatch:\n{treespec}\n != \n{expected_treespec}')
iter_leaves = iter(leaves)
raise TypeError(f'Tree structure mismatch; expected: {expected_treespec}, got: {treespec}.')

grouped = [
[next(iter_leaves) for _ in range(inner_size)]
for __ in range(outer_size)
] # fmt: skip
leaves[offset : offset + inner_size]
for offset in range(0, outer_size * inner_size, inner_size)
]
transposed = zip(*grouped)
subtrees = map(outer_treespec.unflatten, transposed)
return inner_treespec.unflatten(subtrees) # type: ignore[arg-type]
Expand Down Expand Up @@ -1086,7 +1093,7 @@ def broadcast_prefix(
>>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
Traceback (most recent call last):
...
ValueError: List arity mismatch: 4 != 3; list: [1, 2, 3, 4].
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
[1, 2, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ typing-modules = ["optree.typing"]
"setup.py" = [
"ANN", # flake8-annotations
]
"benchmark.py" = [
"PLW2901", # redefined-loop-name
]

[tool.ruff.flake8-annotations]
allow-star-arg-any = true
Expand Down
Loading