diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b5ea84b7..4a7ea443 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,10 +27,6 @@ repos: repo: https://github.com/myint/autoflake rev: v2.3.1 -- repo: https://github.com/asottile/setup-cfg-fmt - rev: v2.5.0 - hooks: - - id: setup-cfg-fmt - repo: https://github.com/asottile/reorder-python-imports rev: v3.13.0 diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 177548f6..4bde16c0 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -42,7 +42,7 @@ def test_enum(): inline-snapshot comes with a special implementation for the following types: -```python exec="1" +``` python exec="1" from inline_snapshot._code_repr import code_repr_dispatch, code_repr for name, obj in sorted( @@ -60,7 +60,7 @@ for name, obj in sorted( Container types like `dict` or `dataclass` need a special implementation because it is necessary that the implementation uses `repr()` for the child elements. -```python exec="1" result="python" +``` python exec="1" result="python" print('--8<-- "src/inline_snapshot/_code_repr.py:list"') ``` diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index 899ba13e..de32a4a0 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -33,9 +33,31 @@ Example: def test_something(): assert 2 + 40 == snapshot(42) ``` +## unmanaged snapshot parts +inline-snapshots manages everything inside `snapshot(...)`, which means that the developer should not change these parts, but there are cases where it is useful to give the developer a bit more control over the snapshot content. -## dirty-equals +Therefor some types will be ignored by inline-snapshot and will **not be updated or fixed**, even if they cause tests to fail. + +These types are: + +* dirty-equals expression +* dynamic code inside `Is(...)` +* and snapshots inside snapshots. + +inline-snapshot is able to handle these types inside the following containers: + +* list +* tuple +* dict +* namedtuple +* dataclass + + +### dirty-equals It might be, that larger snapshots with many lists and dictionaries contain some values which change frequently and are not relevant for the test. They might be part of larger data structures and be difficult to normalize. @@ -82,7 +104,7 @@ Example: inline-snapshot tries to change only the values that it needs to change in order to pass the equality comparison. This allows to replace parts of the snapshot with [dirty-equals](https://dirty-equals.helpmanual.io/latest/) expressions. -This expressions are preserved as long as the `==` comparison with them is `True`. +This expressions are preserved even if the `==` comparison with them is `False`. Example: @@ -159,8 +181,149 @@ Example: ) ``` -!!! note - The current implementation looks only into lists, dictionaries and tuples and not into the representation of other data structures. +### Is(...) + +`Is()` can be used to put runtime values inside snapshots. +It tells inline-snapshot that the developer wants control over some part of the snapshot. + + +``` python +from inline_snapshot import snapshot, Is + +current_version = "1.5" + + +def request(): + return {"data": "page data", "version": current_version} + + +def test_function(): + assert request() == snapshot( + {"data": "page data", "version": Is(current_version)} + ) +``` + +The `current_version` can now be changed without having to correct the snapshot. + +`Is()` can also be used when the snapshot is evaluated multiple times. + +=== "original code" + + ``` python + from inline_snapshot import snapshot, Is + + + def test_function(): + for c in "abc": + assert [c, "correct"] == snapshot([Is(c), "wrong"]) + ``` + +=== "--inline-snapshot=fix" + + ``` python hl_lines="6" + from inline_snapshot import snapshot, Is + + + def test_function(): + for c in "abc": + assert [c, "correct"] == snapshot([Is(c), "correct"]) + ``` + +### inner snapshots + +Snapshots can be used inside other snapshots in different use cases. + +#### conditional snapshots +It is possible to describe version specific parts of snapshots by replacing the specific part with `#!python snapshot() if some_condition else snapshot()`. +The test has to be executed in each specific condition to fill the snapshots. + +The following example shows how this can be used to run a tests with two different library versions: + +=== "my_lib v1" + + + ``` python + version = 1 + + + def get_schema(): + return [{"name": "var_1", "type": "int"}] + ``` + +=== "my_lib v2" + + + ``` python + version = 2 + + + def get_schema(): + return [{"name": "var_1", "type": "string"}] + ``` + + + +``` python +from inline_snapshot import snapshot +from my_lib import version, get_schema + + +def test_function(): + assert get_schema() == snapshot( + [ + { + "name": "var_1", + "type": snapshot("int") if version < 2 else snapshot("string"), + } + ] + ) +``` + +The advantage of this approach is that the test uses always the correct values for each library version. + +#### common snapshot parts + +Another usecase is the extraction of common snapshot parts into an extra snapshot: + + +``` python +from inline_snapshot import snapshot + + +def some_data(name): + return {"header": "really long header\n" * 5, "your name": name} + + +def test_function(): + + header = snapshot( + """\ +really long header +really long header +really long header +really long header +really long header +""" + ) + + assert some_data("Tom") == snapshot( + { + "header": header, + "your name": "Tom", + } + ) + + assert some_data("Bob") == snapshot( + { + "header": header, + "your name": "Bob", + } + ) +``` + +This simplifies test data and allows inline-snapshot to update your values if required. +It makes also sure that the header is the same in both cases. + ## pytest options diff --git a/docs/pytest.md b/docs/pytest.md index 5d825618..a7d34aca 100644 --- a/docs/pytest.md +++ b/docs/pytest.md @@ -11,7 +11,7 @@ inline-snapshot provides one pytest option with different flags (*create*, Snapshot comparisons return always `True` if you use one of the flags *create*, *fix* or *review*. This is necessary because the whole test needs to be run to fix all snapshots like in this case: -```python +``` python from inline_snapshot import snapshot @@ -30,7 +30,7 @@ def test_something(): Approve the changes of the given [category](categories.md). These flags can be combined with *report* and *review*. -```python title="test_something.py" +``` python title="test_something.py" from inline_snapshot import snapshot diff --git a/docs/testing.md b/docs/testing.md index b712b7aa..9f6b1a2b 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -59,7 +59,7 @@ The following example shows how you can use the `Example` class to test what inl report=snapshot( """\ Error: one snapshot is missing a value (--inline-snapshot=create) - You can also use --inline-snapshot=review to approve the changes interactiv\ + You can also use --inline-snapshot=review to approve the changes interactively\ """ ), ).run_pytest( # run with create flag and check the changed files diff --git a/pyproject.toml b/pyproject.toml index e51aecb4..7c0f4f02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,21 @@ extra-dependencies = [ "pydantic" ] +[tool.hatch.envs.default] +extra-dependencies = [ + "dirty-equals>=0.7.0", + "hypothesis>=6.75.5", + "mypy>=1.2.0", + "pyright>=1.1.359", + "pytest-subtests>=0.11.0", + "time-machine>=2.10.0", + "pydantic", + + "pytest-xdist", + "textual", + "mutmut @ {root:uri}/../mutmut" +] + [[tool.hatch.envs.types.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12","3.13"] @@ -144,3 +159,11 @@ venvPath = ".nox" [tool.scriv] format = "md" version = "literal: pyproject.toml: project.version" + +[tool.mutmut] +also_copy=[ + "src/inline_snapshot/py.typed", + "conftest.py", + "pyproject.toml", + "README.md" +] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..0054d4d0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ + +-e ./mutants +dirty-equals>=0.7.0 +hypothesis>=6.75.5 +mypy>=1.2.0 +pyright>=1.1.359 +pytest-subtests>=0.11.0 +time-machine>=2.10.0 +pydantic + +pytest-xdist +-e ../mutmut diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index f06c02e6..74b443d4 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -3,9 +3,19 @@ from ._external import external from ._external import outsource from ._inline_snapshot import snapshot +from ._is import Is from ._types import Category from ._types import Snapshot -__all__ = ["snapshot", "external", "outsource", "customize_repr", "HasRepr"] +__all__ = [ + "snapshot", + "external", + "outsource", + "customize_repr", + "HasRepr", + "Is", + "Category", + "Snapshot", +] __version__ = "0.13.3" diff --git a/src/inline_snapshot/_adapter/__init__.py b/src/inline_snapshot/_adapter/__init__.py new file mode 100644 index 00000000..2f699011 --- /dev/null +++ b/src/inline_snapshot/_adapter/__init__.py @@ -0,0 +1,3 @@ +from .adapter import get_adapter_type + +__all__ = ("get_adapter_type",) diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py new file mode 100644 index 00000000..a1d58ac6 --- /dev/null +++ b/src/inline_snapshot/_adapter/adapter.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import ast +import typing +from dataclasses import is_dataclass + +from inline_snapshot._source_file import SourceFile + + +def get_adapter_type(value): + if is_dataclass(value): + from .dataclass_adapter import DataclassAdapter + + return DataclassAdapter + + if isinstance(value, list): + from .sequence_adapter import ListAdapter + + return ListAdapter + + if isinstance(value, tuple): + from .sequence_adapter import TupleAdapter + + return TupleAdapter + + if isinstance(value, dict): + from .dict_adapter import DictAdapter + + return DictAdapter + + from .value_adapter import ValueAdapter + + return ValueAdapter + + +class Item(typing.NamedTuple): + value: typing.Any + node: ast.expr + + +class Adapter: + context: SourceFile + + def __init__(self, context): + self.context = context + + def get_adapter(self, old_value, new_value) -> Adapter: + if type(old_value) is not type(new_value): + from .value_adapter import ValueAdapter + + return ValueAdapter(self.context) + + adapter_type = get_adapter_type(old_value) + if adapter_type is not None: + return adapter_type(self.context) + assert False + + def assign(self, old_value, old_node, new_value): + raise NotImplementedError diff --git a/src/inline_snapshot/_adapter/dataclass_adapter.py b/src/inline_snapshot/_adapter/dataclass_adapter.py new file mode 100644 index 00000000..ababf094 --- /dev/null +++ b/src/inline_snapshot/_adapter/dataclass_adapter.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import ast +import warnings +from dataclasses import fields +from dataclasses import MISSING + +from inline_snapshot._adapter.value_adapter import ValueAdapter + +from .._change import CallArg +from .._change import Delete +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import Item + + +class DataclassAdapter(Adapter): + + def items(self, value, node): + assert isinstance(node, ast.Call) + assert not node.args + assert all(kw.arg for kw in node.keywords) + + return [ + Item(value=self.argument(value, kw.arg), node=kw.value) + for kw in node.keywords + if kw.arg + ] + + def arguments(self, value): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + + if field.default != MISSING and field.default == field_value: + continue + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + continue + + kwargs[field.name] = field_value + + return ([], kwargs) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + def assign(self, old_value, old_node, new_value): + if old_node is None: + + value = yield from ValueAdapter(self.context).assign( + old_value, old_node, new_value + ) + return value + + assert isinstance(old_node, ast.Call) + + for kw in old_node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + new_args, new_kwargs = self.arguments(new_value) + + result_kwargs = {} + for kw in old_node.keywords: + if not kw.arg in new_kwargs: + # delete entries + yield Delete( + "fix", + self.context._source, + kw.value, + self.argument(old_value, kw.arg), + ) + + old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_kwargs.items(): + if key not in old_node_kwargs: + # add new values + to_insert.append((key, new_value_element)) + result_kwargs[key] = new_value_element + else: + node = old_node_kwargs[key] + + # check values with same keys + old_value_element = self.argument(old_value, key) + result_kwargs[key] = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, node, new_value_element) + + if to_insert: + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=self.context._value_to_code(value), + new_value=value, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=self.context._value_to_code(value), + new_value=value, + ) + + return type(old_value)(**result_kwargs) diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py new file mode 100644 index 00000000..a91a2cbc --- /dev/null +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import ast +import warnings + +from .._change import Delete +from .._change import DictInsert +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import Item + + +class DictAdapter(Adapter): + def items(self, value, node): + assert isinstance(node, ast.Dict) + + result = [] + + for value_key, node_key, node_value in zip( + value.keys(), node.keys, node.values + ): + if node_key is not None: + try: + # this is just a sanity check, dicts should be ordered + node_key = ast.literal_eval(node_key) + except Exception: + pass + else: + assert node_key == value_key + + result.append(Item(value=value[value_key], node=node_value)) + + return result + + def assign(self, old_value, old_node, new_value): + if old_node is not None: + assert isinstance(old_node, ast.Dict) + assert len(old_value) == len(old_node.keys) + + for key, value in zip(old_node.keys, old_node.values): + if key is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + for value, node in zip(old_value.keys(), old_node.keys): + + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except: + continue + assert node_value == value + + result = {} + for key, node in zip( + old_value.keys(), + (old_node.values if old_node is not None else [None] * len(old_value)), + ): + if not key in new_value: + # delete entries + yield Delete("fix", self.context._source, node, old_value[key]) + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_value.items(): + if key not in old_value: + # add new values + to_insert.append((key, new_value_element)) + result[key] = new_value_element + else: + if isinstance(old_node, ast.Dict): + node = old_node.values[list(old_value.keys()).index(key)] + else: + node = None + # check values with same keys + result[key] = yield from self.get_adapter( + old_value[key], new_value[key] + ).assign(old_value[key], node, new_value[key]) + + if to_insert: + new_code = [ + (self.context._value_to_code(k), self.context._value_to_code(v)) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context._source, + old_node, + insert_pos, + new_code, + to_insert, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + new_code = [ + (self.context._value_to_code(k), self.context._value_to_code(v)) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context._source, + old_node, + len(old_value), + new_code, + to_insert, + ) + + return result diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py new file mode 100644 index 00000000..889d3c7c --- /dev/null +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import ast +import warnings +from collections import defaultdict + +from .._align import add_x +from .._align import align +from .._change import Delete +from .._change import ListInsert +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import Item + + +class SequenceAdapter(Adapter): + node_type: type + value_type: type + + def items(self, value, node): + + assert isinstance(node, self.node_type), (node, self) + assert len(value) == len(node.elts) + + return [Item(value=v, node=n) for v, n in zip(value, node.elts)] + + def assign(self, old_value, old_node, new_value): + if old_node is not None: + assert isinstance( + old_node, ast.List if isinstance(old_value, list) else ast.Tuple + ) + + for e in old_node.elts: + if isinstance(e, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.filename, + lineno=e.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + diff = add_x(align(old_value, new_value)) + old = zip( + old_value, + old_node.elts if old_node is not None else [None] * len(old_value), + ) + new = iter(new_value) + old_position = 0 + to_insert = defaultdict(list) + result = [] + for c in diff: + if c in "mx": + old_value_element, old_node_element = next(old) + new_value_element = next(new) + v = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, old_node_element, new_value_element) + result.append(v) + old_position += 1 + elif c == "i": + new_value_element = next(new) + new_code = self.context._value_to_code(new_value_element) + result.append(new_value_element) + to_insert[old_position].append((new_code, new_value_element)) + elif c == "d": + old_value_element, old_node_element = next(old) + yield Delete( + "fix", self.context._source, old_node_element, old_value_element + ) + old_position += 1 + else: + assert False + + for position, code_values in to_insert.items(): + yield ListInsert( + "fix", self.context._source, old_node, position, *zip(*code_values) + ) + + return self.value_type(result) + + +class ListAdapter(SequenceAdapter): + node_type = ast.List + value_type = list + + +class TupleAdapter(SequenceAdapter): + node_type = ast.Tuple + value_type = tuple diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py new file mode 100644 index 00000000..93326065 --- /dev/null +++ b/src/inline_snapshot/_adapter/value_adapter.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from inline_snapshot._unmanaged import is_unmanaged +from inline_snapshot._unmanaged import update_allowed +from inline_snapshot._utils import value_to_token + +from .._change import Replace +from .adapter import Adapter + + +class ValueAdapter(Adapter): + + def assign(self, old_value, old_node, new_value): + # generic fallback + + # because IsStr() != IsStr() + + if is_unmanaged(old_value): + return old_value + + if old_node is None: + new_token = [] + else: + new_token = value_to_token(new_value) + + if not old_value == new_value: + flag = "fix" + elif ( + old_node is not None + and update_allowed(old_value) + and self.context._token_of_node(old_node) != new_token + ): + flag = "update" + else: + # equal and equal repr + return old_value + + new_code = self.context._token_to_code(new_token) + + yield Replace( + node=old_node, + file=self.context._source, + new_code=new_code, + flag=flag, + old_value=old_value, + new_value=new_value, + ) + + return new_value diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 691d4f7a..05c888f7 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any from typing import cast +from typing import DefaultDict from typing import Dict from typing import List from typing import Optional @@ -11,7 +12,7 @@ from asttokens.util import Token from executing.executing import EnhancedAST -from executing.executing import Source +from inline_snapshot._source_file import SourceFile from ._rewrite_code import ChangeRecorder from ._rewrite_code import end_of @@ -21,11 +22,11 @@ @dataclass() class Change: flag: str - source: Source + file: SourceFile @property def filename(self): - return self.source.filename + return self.file.filename def apply(self): raise NotImplementedError() @@ -76,7 +77,7 @@ class Replace(Change): def apply(self): change = ChangeRecorder.current.new_change() - range = self.source.asttokens().get_text_positions(self.node, False) + range = self.file.asttokens().get_text_positions(self.node, False) change.replace(range, self.new_code, filename=self.filename) @@ -87,40 +88,21 @@ class CallArg(Change): arg_name: Optional[str] new_code: str - old_value: Any new_value: Any - def apply(self): - change = ChangeRecorder.current.new_change() - tokens = list(self.source.asttokens().get_tokens(self.node)) - - call = self.node - tokens = list(self.source.asttokens().get_tokens(call)) - assert isinstance(call, ast.Call) - assert len(call.args) == 0 - assert len(call.keywords) == 0 - assert tokens[-2].string == "(" - assert tokens[-1].string == ")" - - assert self.arg_pos == 0 - assert self.arg_name == None - - change = ChangeRecorder.current.new_change() - change.set_tags("inline_snapshot") - change.replace( - (end_of(tokens[-2]), start_of(tokens[-1])), - self.new_code, - filename=self.filename, - ) +TokenRange = Tuple[Token, Token] -TokenRange = Tuple[Token, Token] +def brace_tokens(source, node) -> TokenRange: + first_token, *_, end_token = source.asttokens().get_tokens(node) + return first_token, end_token def generic_sequence_update( - source: Source, - parent: Union[ast.List, ast.Tuple, ast.Dict], + source: SourceFile, + parent: Union[ast.List, ast.Tuple, ast.Dict, ast.Call], + brace_tokens: TokenRange, parent_elements: List[Union[TokenRange, None]], to_insert: Dict[int, List[str]], ): @@ -128,7 +110,7 @@ def generic_sequence_update( new_code = [] deleted = False - last_token, *_, end_token = source.asttokens().get_tokens(parent) + last_token, end_token = brace_tokens is_start = True elements = 0 @@ -169,7 +151,7 @@ def generic_sequence_update( code = ", " + code if elements == 1 and isinstance(parent, ast.Tuple): - # trailing comma for tuples (1,)i + # trailing comma for tuples (1,) code += "," rec.replace( @@ -180,21 +162,23 @@ def generic_sequence_update( def apply_all(all_changes: List[Change]): - by_parent: Dict[EnhancedAST, List[Union[Delete, DictInsert, ListInsert]]] = ( - defaultdict(list) - ) - sources: Dict[EnhancedAST, Source] = {} + by_parent: Dict[ + EnhancedAST, List[Union[Delete, DictInsert, ListInsert, CallArg]] + ] = defaultdict(list) + sources: Dict[EnhancedAST, SourceFile] = {} for change in all_changes: if isinstance(change, Delete): node = cast(EnhancedAST, change.node).parent + if isinstance(node, ast.keyword): + node = node.parent by_parent[node].append(change) - sources[node] = change.source + sources[node] = change.file - elif isinstance(change, (DictInsert, ListInsert)): + elif isinstance(change, (DictInsert, ListInsert, CallArg)): node = cast(EnhancedAST, change.node) by_parent[node].append(change) - sources[node] = change.source + sources[node] = change.file else: change.apply() @@ -218,11 +202,57 @@ def list_token_range(entry): generic_sequence_update( source, parent, + brace_tokens(source, parent), [None if e in to_delete else list_token_range(e) for e in parent.elts], to_insert, ) - elif isinstance(parent, (ast.Dict)): + elif isinstance(parent, ast.Call): + to_delete = { + change.node for change in changes if isinstance(change, Delete) + } + atok = source.asttokens() + + def arg_token_range(node): + if isinstance(node.parent, ast.keyword): + node = node.parent + r = list(atok.get_tokens(node)) + return r[0], r[-1] + + braces_left = atok.next_token(list(atok.get_tokens(parent.func))[-1]) + assert braces_left.string == "(" + braces_right = list(atok.get_tokens(parent))[-1] + assert braces_right.string == ")" + + to_insert = DefaultDict(list) + + for change in changes: + if isinstance(change, CallArg): + if change.arg_name is not None: + position = ( + change.arg_pos + if change.arg_pos is not None + else len(parent.args) + len(parent.keywords) + ) + to_insert[position].append( + f"{change.arg_name} = {change.new_code}" + ) + else: + assert change.arg_pos is not None + to_insert[change.arg_pos].append(change.new_code) + + generic_sequence_update( + source, + parent, + (braces_left, braces_right), + [ + None if e in to_delete else arg_token_range(e) + for e in parent.args + [kw.value for kw in parent.keywords] + ], + to_insert, + ) + + elif isinstance(parent, ast.Dict): to_delete = { change.node for change in changes if isinstance(change, Delete) } @@ -241,6 +271,7 @@ def dict_token_range(key, value): generic_sequence_update( source, parent, + brace_tokens(source, parent), [ None if value in to_delete else dict_token_range(key, value) for key, value in zip(parent.keys, parent.values) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 9a5dcd3a..c34f5a27 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -62,7 +62,7 @@ def customize_repr(f): """Register a funtion which should be used to get the code representation of a object. - ```python + ``` python @customize_repr def _(obj: MyCustomClass): return f"MyCustomClass(attr={repr(obj.attr)})" diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 586965a4..bb7f77ce 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -1,21 +1,19 @@ import ast import copy import inspect -import tokenize -import warnings -from collections import defaultdict -from pathlib import Path from typing import Any from typing import Dict # noqa from typing import Iterator +from typing import List from typing import Set from typing import Tuple # noqa from typing import TypeVar from executing import Source +from inline_snapshot._adapter.adapter import Adapter +from inline_snapshot._source_file import SourceFile -from ._align import add_x -from ._align import align +from ._adapter import get_adapter_type from ._change import CallArg from ._change import Change from ._change import Delete @@ -24,12 +22,11 @@ from ._change import Replace from ._code_repr import code_repr from ._exceptions import UsageError -from ._format import format_code +from ._is import Is from ._sentinels import undefined from ._types import Category -from ._utils import ignore_tokens -from ._utils import normalize -from ._utils import simple_token +from ._types import Snapshot +from ._unmanaged import update_allowed from ._utils import value_to_token @@ -37,7 +34,7 @@ class NotImplementedYet(Exception): pass -snapshots = {} # type: Dict[Tuple[int, int], Snapshot] +snapshots = {} # type: Dict[Tuple[int, int], SnapshotReference] _active = False @@ -78,36 +75,40 @@ def ignore_old_value(): return _update_flags.fix or _update_flags.update -class GenericValue: +class GenericValue(Snapshot): _new_value: Any _old_value: Any _current_op = "undefined" _ast_node: ast.Expr - _source: Source + _file: SourceFile - def _token_of_node(self, node): + def get_adapter(self, value): + return get_adapter_type(value)(self._file) - return list( - normalize( - [ - simple_token(t.type, t.string) - for t in self._source.asttokens().get_tokens(node) - if t.type not in ignore_tokens - ] - ) - ) + def _re_eval(self, value): - def _format(self, text): - if self._source is None: - return text - else: - return format_code(text, Path(self._source.filename)) + def re_eval(old_value, node, value): + assert type(old_value) is type(value) - def _token_to_code(self, tokens): - return self._format(tokenize.untokenize(tokens)).strip() + adapter = self.get_adapter(old_value) + if adapter is not None and hasattr(adapter, "items"): + old_items = adapter.items(old_value, node) + new_items = adapter.items(value, node) + assert len(old_items) == len(new_items) + + for old_item, new_item in zip(old_items, new_items): + re_eval(old_item.value, old_item.node, new_item.value) + + elif isinstance(old_value, Is): + old_value.value = value.value + + else: + if update_allowed(old_value): + assert old_value == value + else: + assert not update_allowed(value) - def _value_to_code(self, value): - return self._token_to_code(value_to_token(value)) + re_eval(self._old_value, self._ast_node, value) def _ignore_old(self): return ( @@ -164,7 +165,7 @@ def __init__(self, old_value, ast_node, source): self._old_value = old_value self._new_value = undefined self._ast_node = ast_node - self._source = source + self._file = SourceFile(source) def _change(self, cls): self.__class__ = cls @@ -175,48 +176,28 @@ def _new_code(self): def _get_changes(self) -> Iterator[Change]: def handle(node, obj): - if isinstance(obj, list): - if not isinstance(node, ast.List): - return - for node_value, value in zip(node.elts, obj): - yield from handle(node_value, value) - elif isinstance(obj, tuple): - if not isinstance(node, ast.Tuple): - return - for node_value, value in zip(node.elts, obj): - yield from handle(node_value, value) - - elif isinstance(obj, dict): - if not isinstance(node, ast.Dict): - return - for value_key, node_key, node_value in zip( - obj.keys(), node.keys, node.values - ): - try: - # this is just a sanity check, dicts should be ordered - node_key = ast.literal_eval(node_key) - except Exception: - pass - else: - assert node_key == value_key - - yield from handle(node_value, obj[value_key]) - else: - if update_allowed(obj): - new_token = value_to_token(obj) - if self._token_of_node(node) != new_token: - new_code = self._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - source=self._source, - new_code=new_code, - flag="update", - old_value=self._old_value, - new_value=self._old_value, - ) - - if self._source is not None: + + adapter = self.get_adapter(obj) + if adapter is not None and hasattr(adapter, "items"): + for item in adapter.items(obj, node): + yield from handle(item.node, item.value) + return + + if update_allowed(obj): + new_token = value_to_token(obj) + if self._file._token_of_node(node) != new_token: + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=self._ast_node, + file=self._file, + new_code=new_code, + flag="update", + old_value=self._old_value, + new_value=self._old_value, + ) + + if self._file._source is not None: yield from handle(self._ast_node, self._old_value) # functions which determine the type @@ -242,19 +223,6 @@ def __getitem__(self, item): return self[item] -try: - import dirty_equals # type: ignore -except ImportError: # pragma: no cover - - def update_allowed(value): - return True - -else: - - def update_allowed(value): - return not isinstance(value, dirty_equals.DirtyEquals) - - def clone(obj): new = copy.deepcopy(obj) if not obj == new: @@ -274,230 +242,31 @@ def clone(obj): class EqValue(GenericValue): _current_op = "x == snapshot" + _changes: List[Change] def __eq__(self, other): global _missing_values if self._old_value is undefined: _missing_values += 1 - def use_valid_old_values(old_value, new_value): - - if ( - isinstance(new_value, list) - and isinstance(old_value, list) - or isinstance(new_value, tuple) - and isinstance(old_value, tuple) - ): - diff = add_x(align(old_value, new_value)) - old = iter(old_value) - new = iter(new_value) - result = [] - for c in diff: - if c in "mx": - old_value_element = next(old) - new_value_element = next(new) - result.append( - use_valid_old_values(old_value_element, new_value_element) - ) - elif c == "i": - result.append(next(new)) - elif c == "d": - pass - else: - assert False - - return type(new_value)(result) - - elif isinstance(new_value, dict) and isinstance(old_value, dict): - result = {} - - for key, new_value_element in new_value.items(): - if key in old_value: - result[key] = use_valid_old_values( - old_value[key], new_value_element - ) - else: - result[key] = new_value_element - - return result - - if new_value == old_value: - return old_value - else: - return new_value - if self._new_value is undefined: - self._new_value = use_valid_old_values(self._old_value, clone(other)) + adapter = Adapter(self._file).get_adapter(self._old_value, other) + it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) + self._changes = [] + while True: + try: + self._changes.append(next(it)) + except StopIteration as ex: + self._new_value = ex.value + break return self._visible_value() == other def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: - - assert self._old_value is not undefined - - def check(old_value, old_node, new_value): - - if ( - isinstance(old_node, ast.List) - and isinstance(new_value, list) - and isinstance(old_value, list) - or isinstance(old_node, ast.Tuple) - and isinstance(new_value, tuple) - and isinstance(old_value, tuple) - ): - for e in old_node.elts: - if isinstance(e, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self._source.filename, - lineno=e.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return - diff = add_x(align(old_value, new_value)) - old = zip(old_value, old_node.elts) - new = iter(new_value) - old_position = 0 - to_insert = defaultdict(list) - for c in diff: - if c in "mx": - old_value_element, old_node_element = next(old) - new_value_element = next(new) - yield from check( - old_value_element, old_node_element, new_value_element - ) - old_position += 1 - elif c == "i": - new_value_element = next(new) - new_code = self._value_to_code(new_value_element) - to_insert[old_position].append((new_code, new_value_element)) - elif c == "d": - old_value_element, old_node_element = next(old) - yield Delete( - "fix", self._source, old_node_element, old_value_element - ) - old_position += 1 - else: - assert False - - for position, code_values in to_insert.items(): - yield ListInsert( - "fix", self._source, old_node, position, *zip(*code_values) - ) - - return - - elif ( - isinstance(old_node, ast.Dict) - and isinstance(new_value, dict) - and isinstance(old_value, dict) - and len(old_value) == len(old_node.keys) - ): - - for key, value in zip(old_node.keys, old_node.values): - if key is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self._source.filename, - lineno=value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return - - for value, node in zip(old_value.keys(), old_node.keys): - assert node is not None - - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except: - continue - assert node_value == value - - for key, node in zip(old_value.keys(), old_node.values): - if key in new_value: - # check values with same keys - yield from check(old_value[key], node, new_value[key]) - else: - # delete entries - yield Delete("fix", self._source, node, old_value[key]) - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_value.items(): - if key not in old_value: - # add new values - to_insert.append((key, new_value_element)) - else: - if to_insert: - new_code = [ - (self._value_to_code(k), self._value_to_code(v)) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self._source, - old_node, - insert_pos, - new_code, - to_insert, - ) - to_insert = [] - insert_pos += 1 - - if to_insert: - new_code = [ - (self._value_to_code(k), self._value_to_code(v)) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self._source, - old_node, - len(old_node.values), - new_code, - to_insert, - ) - - return - - # generic fallback - - # because IsStr() != IsStr() - if type(old_value) is type(new_value) and not update_allowed(new_value): - return - - if old_node is None: - new_token = [] - else: - new_token = value_to_token(new_value) - - if not old_value == new_value: - flag = "fix" - elif ( - self._ast_node is not None - and update_allowed(old_value) - and self._token_of_node(old_node) != new_token - ): - flag = "update" - else: - return - - new_code = self._token_to_code(new_token) - - yield Replace( - node=old_node, - source=self._source, - new_code=new_code, - flag=flag, - old_value=old_value, - new_value=new_value, - ) - - yield from check(self._old_value, self._ast_node, self._new_value) + return iter(self._changes) class MinMaxValue(GenericValue): @@ -522,7 +291,7 @@ def _generic_cmp(self, other): return self.cmp(self._visible_value(), other) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: new_token = value_to_token(self._new_value) @@ -532,17 +301,17 @@ def _get_changes(self) -> Iterator[Change]: flag = "trim" elif ( self._ast_node is not None - and self._token_of_node(self._ast_node) != new_token + and self._file._token_of_node(self._ast_node) != new_token ): flag = "update" else: return - new_code = self._token_to_code(new_token) + new_code = self._file._token_to_code(new_token) yield Replace( node=self._ast_node, - source=self._source, + file=self._file, new_code=new_code, flag=flag, old_value=self._old_value, @@ -612,7 +381,7 @@ def __contains__(self, item): return item in self._old_value def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: @@ -625,19 +394,25 @@ def _get_changes(self) -> Iterator[Change]: for old_value, old_node in zip(self._old_value, elements): if old_value not in self._new_value: yield Delete( - flag="trim", source=self._source, node=old_node, old_value=old_value + flag="trim", + file=self._file, + node=old_node, + old_value=old_value, ) continue # check for update new_token = value_to_token(old_value) - if old_node is not None and self._token_of_node(old_node) != new_token: - new_code = self._token_to_code(new_token) + if ( + old_node is not None + and self._file._token_of_node(old_node) != new_token + ): + new_code = self._file._token_to_code(new_token) yield Replace( node=old_node, - source=self._source, + file=self._file, new_code=new_code, flag="update", old_value=old_value, @@ -648,10 +423,10 @@ def _get_changes(self) -> Iterator[Change]: if new_values: yield ListInsert( flag="fix", - source=self._source, + file=self._file, node=self._ast_node, position=len(self._old_value), - new_code=[self._value_to_code(v) for v in new_values], + new_code=[self._file._value_to_code(v) for v in new_values], new_values=new_values, ) @@ -665,31 +440,39 @@ def __getitem__(self, index): if self._new_value is undefined: self._new_value = {} - old_value = self._old_value - if old_value is undefined: - _missing_values += 1 - old_value = {} - - child_node = None - if self._ast_node is not None: - assert isinstance(self._ast_node, ast.Dict) - if index in old_value: - pos = list(old_value.keys()).index(index) - child_node = self._ast_node.values[pos] - if index not in self._new_value: + old_value = self._old_value + if old_value is undefined: + _missing_values += 1 + old_value = {} + + child_node = None + if self._ast_node is not None: + assert isinstance(self._ast_node, ast.Dict) + if index in old_value: + pos = list(old_value.keys()).index(index) + child_node = self._ast_node.values[pos] + self._new_value[index] = UndecidedValue( - old_value.get(index, undefined), child_node, self._source + old_value.get(index, undefined), child_node, self._file ) return self._new_value[index] + def _re_eval(self, value): + super()._re_eval(value) + + if self._new_value is not undefined and self._old_value is not undefined: + for key, s in self._new_value.items(): + if key in self._old_value: + s._re_eval(self._old_value[key]) + def _new_code(self): return ( "{" + ", ".join( [ - f"{self._value_to_code(k)}: {v._new_code()}" + f"{self._file._value_to_code(k)}: {v._new_code()}" for k, v in self._new_value.items() if not isinstance(v, UndecidedValue) ] @@ -713,7 +496,7 @@ def _get_changes(self) -> Iterator[Change]: yield from self._new_value[key]._get_changes() else: # delete entries - yield Delete("trim", self._source, node, self._old_value[key]) + yield Delete("trim", self._file, node, self._old_value[key]) to_insert = [] for key, new_value_element in self._new_value.items(): @@ -724,10 +507,10 @@ def _get_changes(self) -> Iterator[Change]: to_insert.append((key, new_value_element._new_code())) if to_insert: - new_code = [(self._value_to_code(k), v) for k, v in to_insert] + new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] yield DictInsert( "create", - self._source, + self._file, self._ast_node, len(self._old_value), new_code, @@ -784,10 +567,18 @@ def snapshot(obj: Any = undefined) -> Any: return obj frame = inspect.currentframe() + assert frame is not None frame = frame.f_back + while "mutmut" in frame.f_code.co_name: + assert frame is not None + frame = frame.f_back + assert frame is not None frame = frame.f_back + while "mutmut" in frame.f_code.co_name: + assert frame is not None + frame = frame.f_back assert frame is not None expr = Source.executing(frame) @@ -802,10 +593,12 @@ def snapshot(obj: Any = undefined) -> Any: node = expr.node if node is None: # we can run without knowing of the calling expression but we will not be able to fix code - snapshots[key] = Snapshot(obj, None) + snapshots[key] = SnapshotReference(obj, None) else: assert isinstance(node, ast.Call) - snapshots[key] = Snapshot(obj, expr) + snapshots[key] = SnapshotReference(obj, expr) + else: + snapshots[key]._re_eval(obj) return snapshots[key]._value @@ -822,7 +615,7 @@ def used_externals(tree): ] -class Snapshot: +class SnapshotReference: def __init__(self, value, expr): self._expr = expr node = expr.node.args[0] if expr is not None and expr.node.args else None @@ -840,16 +633,18 @@ def _changes(self): new_code = self._value._new_code() yield CallArg( - "create", - self._value._source, - self._expr.node if self._expr is not None else None, - 0, - None, - new_code, - self._value._old_value, - self._value._new_value, + flag="create", + file=self._value._file, + node=self._expr.node if self._expr is not None else None, + arg_pos=0, + arg_name=None, + new_code=new_code, + new_value=self._value._new_value, ) else: yield from self._value._get_changes() + + def _re_eval(self, obj): + self._value._re_eval(obj) diff --git a/src/inline_snapshot/_is.py b/src/inline_snapshot/_is.py new file mode 100644 index 00000000..1f695397 --- /dev/null +++ b/src/inline_snapshot/_is.py @@ -0,0 +1,6 @@ +class Is: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return self.value == other diff --git a/src/inline_snapshot/_rewrite_code.py b/src/inline_snapshot/_rewrite_code.py index 70eb9b6e..0cab4c56 100644 --- a/src/inline_snapshot/_rewrite_code.py +++ b/src/inline_snapshot/_rewrite_code.py @@ -98,12 +98,8 @@ def __init__(self, change_recorder): self.change_recorder._changes.append(self) self.change_id = self._next_change_id - self._tags = [] type(self)._next_change_id += 1 - def set_tags(self, *tags): - self._tags = tags - def replace(self, node, new_contend, *, filename): assert isinstance(new_contend, str) @@ -128,7 +124,7 @@ def _replace(self, filename, range, new_contend): class SourceFile: - def __init__(self, filename): + def __init__(self, filename: pathlib.Path): self.replacements: list[Replacement] = [] self.filename = filename self.source = self.filename.read_text("utf-8") diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py new file mode 100644 index 00000000..6e44b8c1 --- /dev/null +++ b/src/inline_snapshot/_source_file.py @@ -0,0 +1,54 @@ +import tokenize +from pathlib import Path + +from executing import Source +from inline_snapshot._format import format_code +from inline_snapshot._unmanaged import is_dirty_equal +from inline_snapshot._utils import normalize +from inline_snapshot._utils import simple_token +from inline_snapshot._utils import value_to_token + +from ._utils import ignore_tokens + + +class SourceFile: + _source = Source + + def __init__(self, source): + if isinstance(source, SourceFile): + self._source = source._source + else: + self._source = source + + @property + def filename(self): + return self._source.filename + + def _format(self, text): + if self._source is None: + return text + else: + return format_code(text, Path(self._source.filename)) + + def asttokens(self): + return self._source.asttokens() + + def _token_to_code(self, tokens): + return self._format(tokenize.untokenize(tokens)).strip() + + def _value_to_code(self, value): + if is_dirty_equal(value): + return "" + return self._token_to_code(value_to_token(value)) + + def _token_of_node(self, node): + + return list( + normalize( + [ + simple_token(t.type, t.string) + for t in self._source.asttokens().get_tokens(node) + if t.type not in ignore_tokens + ] + ) + ) diff --git a/src/inline_snapshot/_unmanaged.py b/src/inline_snapshot/_unmanaged.py new file mode 100644 index 00000000..be8ccf54 --- /dev/null +++ b/src/inline_snapshot/_unmanaged.py @@ -0,0 +1,24 @@ +from ._is import Is +from ._types import Snapshot + +try: + import dirty_equals # type: ignore +except ImportError: # pragma: no cover + + def is_dirty_equal(value): + return False + +else: + + def is_dirty_equal(value): + return isinstance(value, dirty_equals.DirtyEquals) or ( + isinstance(value, type) and issubclass(value, dirty_equals.DirtyEquals) + ) + + +def update_allowed(value): + return not (is_dirty_equal(value) or isinstance(value, (Is, Snapshot))) # type: ignore + + +def is_unmanaged(value): + return not update_allowed(value) diff --git a/src/inline_snapshot/pytest_plugin.py b/src/inline_snapshot/pytest_plugin.py index d97a4407..c164d4e8 100644 --- a/src/inline_snapshot/pytest_plugin.py +++ b/src/inline_snapshot/pytest_plugin.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest +from inline_snapshot.testing._example import init_env from rich import box from rich.console import Console from rich.panel import Panel @@ -54,6 +55,8 @@ def xdist_running(config): def pytest_configure(config): global flags + init_env() + _config.config = _config.read_config(config.rootpath / "pyproject.toml") if config.option.inline_snapshot is None: @@ -264,7 +267,7 @@ def report(flag, message, message_n): if sum(snapshot_changes.values()) != 0: console.print( - "\nYou can also use [b]--inline-snapshot=review[/] to approve the changes interactiv", + "\nYou can also use [b]--inline-snapshot=review[/] to approve the changes interactively", highlight=False, ) diff --git a/src/inline_snapshot/syntax_warnings.py b/src/inline_snapshot/syntax_warnings.py new file mode 100644 index 00000000..35dc21a0 --- /dev/null +++ b/src/inline_snapshot/syntax_warnings.py @@ -0,0 +1,2 @@ +class InlineSnapshotSyntaxWarning(Warning): + pass diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index a52ae0b0..7a8926f1 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -5,6 +5,7 @@ import platform import re import subprocess as sp +import traceback from argparse import ArgumentParser from pathlib import Path from tempfile import TemporaryDirectory @@ -21,6 +22,17 @@ from .._types import Snapshot +def init_env(): + import inline_snapshot._inline_snapshot as inline_snapshot + + inline_snapshot.snapshots = {} + inline_snapshot._update_flags = inline_snapshot.Flags() + inline_snapshot._active = True + external.storage = None + inline_snapshot._files_with_snapshots = set() + inline_snapshot._missing_values = 0 + + @contextlib.contextmanager def snapshot_env(): import inline_snapshot._inline_snapshot as inline_snapshot @@ -82,6 +94,14 @@ def __init__(self, files: str | dict[str, str]): self.files = files + self.dump_files() + + def dump_files(self): + for name, content in self.files.items(): + print(f"file: {name}") + print(content) + print() + def _write_files(self, dir: Path): for name, content in self.files.items(): (dir / name).write_text(content) @@ -147,6 +167,7 @@ def run_inline( try: for filename in tmp_path.glob("*.py"): globals: dict[str, Any] = {} + print("run> pytest", filename) exec( compile(filename.read_text("utf-8"), filename, "exec"), globals, @@ -157,6 +178,7 @@ def run_inline( if k.startswith("test_") and callable(v): v() except Exception as e: + print(traceback.format_exc()) assert raises == f"{type(e).__name__}:\n" + str(e) finally: diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py new file mode 100644 index 00000000..4d81b7df --- /dev/null +++ b/tests/adapter/test_dataclass.py @@ -0,0 +1,224 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + +from tests.warns import warns + + +def test_unmanaged(): + + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + assert A(a=2,b=4) == snapshot(A(a=1,b=Is(1))), "not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + assert A(a=2,b=4) == snapshot(A(a=2,b=Is(1))), "not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_reeval(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + for c in "ab": + assert A(a=1,b=c) == snapshot(A(a=2,b=Is(c))) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + for c in "ab": + assert A(a=1,b=c) == snapshot(A(a=1,b=Is(c))) +""" + } + ), + ) + + +def test_default_value(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:int=field(default_factory=list) + +def test_something(): + for c in "ab": + assert A(a=c) == snapshot(A(a=Is(c),b=2,c=[])) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:int=field(default_factory=list) + +def test_something(): + for c in "ab": + assert A(a=c) == snapshot(A(a=Is(c))) +""" + } + ), + ) + + +def test_disabled(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(a=5)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_starred_warns(): + with warns( + snapshot( + [ + ( + 10, + "InlineSnapshotSyntaxWarning: star-expressions are not supported inside snapshots", + ) + ] + ), + include_line=True, + ): + Example( + """ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(**{"a":5})),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_add_argument(): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=0 + b:int=0 + c:int=0 + +def test_something(): + assert A(a=3,b=3,c=3) == snapshot(A(b=3)),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=0 + b:int=0 + c:int=0 + +def test_something(): + assert A(a=3,b=3,c=3) == snapshot(A(a = 3, b=3, c = 3)),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) diff --git a/tests/adapter/test_general.py b/tests/adapter/test_general.py new file mode 100644 index 00000000..abbc5791 --- /dev/null +++ b/tests/adapter/test_general.py @@ -0,0 +1,31 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +def test_adapter_mismatch(): + + Example( + """\ +from inline_snapshot import snapshot + + +def test_thing(): + assert [1,2] == snapshot({1:2}) + + """ + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + + +def test_thing(): + assert [1,2] == snapshot([1, 2]) + + \ +""" + } + ), + ) diff --git a/tests/conftest.py b/tests/conftest.py index 5ef37fc2..a91e1d44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ from inline_snapshot._format import format_code from inline_snapshot._inline_snapshot import Flags from inline_snapshot._rewrite_code import ChangeRecorder +from inline_snapshot._types import Category from inline_snapshot.testing._example import snapshot_env pytest_plugins = "pytester" @@ -53,7 +54,7 @@ def w(source_code, *, flags="", reported_flags=None, number=1): @pytest.fixture() -def source(tmp_path): +def source(tmp_path: Path): filecount = 1 @dataclass @@ -64,8 +65,8 @@ class Source: number_snapshots: int = 0 number_changes: int = 0 - def run(self, *flags): - flags = Flags({*flags}) + def run(self, *flags_arg: Category): + flags = Flags({*flags_arg}) nonlocal filecount filename: Path = tmp_path / f"test_{filecount}.py" @@ -288,7 +289,10 @@ def format(self): ) def pyproject(self, source): - (pytester.path / "pyproject.toml").write_text(source, "utf-8") + self.write_file("pyproject.toml", source) + + def write_file(self, filename, content): + (pytester.path / filename).write_text(content, "utf-8") def storage(self): dir = pytester.path / ".inline-snapshot" / "external" diff --git a/tests/test_change.py b/tests/test_change.py new file mode 100644 index 00000000..cbe82589 --- /dev/null +++ b/tests/test_change.py @@ -0,0 +1,90 @@ +import ast + +import pytest +from executing import Source +from inline_snapshot._change import apply_all +from inline_snapshot._change import CallArg +from inline_snapshot._change import Delete +from inline_snapshot._change import Replace +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot._rewrite_code import ChangeRecorder +from inline_snapshot._source_file import SourceFile + + +@pytest.fixture +def check_change(tmp_path): + i = 0 + + def w(source, changes, new_code): + nonlocal i + + filename = tmp_path / f"test_{i}.py" + i += 1 + + filename.write_text(source) + print(f"\ntest: {source}") + + source = Source.for_filename(filename) + module = source.tree + context = SourceFile(source) + + call = module.body[0].value + assert isinstance(call, ast.Call) + + with ChangeRecorder().activate() as cr: + apply_all(changes(context, call)) + + cr.virtual_write() + + cr.dump() + + assert list(cr.files())[0].source == new_code + + return w + + +def test_change_function_args(check_change): + + check_change( + "f(a,b=2)", + lambda source, call: [ + Replace( + flag="fix", + file=source, + node=call.args[0], + new_code="22", + old_value=0, + new_value=0, + ) + ], + snapshot("f(22,b=2)"), + ) + + check_change( + "f(a,b=2)", + lambda source, call: [ + Delete( + flag="fix", + file=source, + node=call.args[0], + old_value=0, + ) + ], + snapshot("f(b=2)"), + ) + + check_change( + "f(a,b=2)", + lambda source, call: [ + CallArg( + flag="fix", + file=source, + node=call, + arg_pos=0, + arg_name=None, + new_code="22", + new_value=22, + ) + ], + snapshot("f(22, a,b=2)"), + ) diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 1ee4f9ea..bdf51b7e 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -368,3 +368,37 @@ def __repr__(self): return "FakeTuple()" assert code_repr(FakeTuple()) == snapshot("FakeTuple()") + + +def test_invalid_repr(check_update): + assert ( + check_update( + """\ +class Thing: + def __repr__(self): + return "+++" + + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + +assert Thing() == snapshot() +""", + flags="create", + ) + == snapshot( + """\ +class Thing: + def __repr__(self): + return "+++" + + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + +assert Thing() == snapshot(HasRepr(Thing, "+++")) +""" + ) + ) diff --git a/tests/test_docs.py b/tests/test_docs.py index 14a4046f..d1543afd 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,11 +1,233 @@ +import itertools import platform import re import sys import textwrap +from collections import defaultdict +from dataclasses import dataclass from pathlib import Path +from typing import Optional import inline_snapshot._inline_snapshot import pytest +from inline_snapshot import snapshot +from inline_snapshot.extra import raises + + +@dataclass +class Block: + code: str + code_header: Optional[str] + block_options: str + line: int + + +def map_code_blocks(file, func, fix=False): + + block_start = re.compile("( *)``` *python(.*)") + block_end = re.compile("```.*") + + header = re.compile("") + + current_code = file.read_text("utf-8") + new_lines = [] + block_lines = [] + options = set() + is_block = False + code = None + indent = "" + block_start_linenum = None + block_options = None + code_header = None + header_line = "" + + for linenumber, line in enumerate(current_code.splitlines(), start=1): + m = block_start.fullmatch(line) + if m and not is_block: + # ``` python + block_start_linenum = linenumber + indent = m[1] + block_options = m[2] + block_lines = [] + is_block = True + continue + + if block_end.fullmatch(line.strip()) and is_block: + # ``` + is_block = False + + code = "\n".join(block_lines) + "\n" + code = textwrap.dedent(code) + if file.suffix == ".py": + code = code.replace("\\\\", "\\") + + try: + new_block = func( + Block( + code=code, + code_header=code_header, + block_options=block_options, + line=block_start_linenum, + ) + ) + except Exception: + print(f"error at block at line {block_start_linenum}") + print(f"{code_header=}") + print(f"{block_options=}") + print(code) + raise + + if new_block.code_header is not None: + new_lines.append(f"{indent}") + + new_lines.append( + f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" + ) + + new_code = new_block.code.rstrip() + if file.suffix == ".py": + new_code = new_code.replace("\\", "\\\\") + new_code = textwrap.indent(new_code, indent) + + new_lines.append(new_code) + + new_lines.append(f"{indent}```") + + header_line = "" + code_header = None + + continue + + if is_block: + block_lines.append(line) + continue + + m = header.fullmatch(line.strip()) + if m: + # comment + header_line = line + code_header = m[1].strip() + continue + else: + if header_line: + new_lines.append(header_line) + code_header = None + header_line = "" + + new_lines.append(line) + + new_code = "\n".join(new_lines) + "\n" + + if fix: + file.write_text(new_code) + else: + assert current_code.splitlines() == new_code.splitlines() + assert current_code == new_code + + +def test_map_code_blocks(tmp_path): + + file = tmp_path / "example.md" + + def test_doc( + markdown_code, + handle_block=lambda block: exec(block.code), + blocks=[], + exception="", + new_markdown_code=None, + ): + + file.write_text(markdown_code) + + recorded_blocks = [] + + with raises(exception): + + def test_block(block): + handle_block(block) + recorded_blocks.append(block) + return block + + map_code_blocks(file, test_block, True) + assert recorded_blocks == blocks + map_code_blocks(file, test_block, False) + + recorded_markdown_code = file.read_text() + if recorded_markdown_code != markdown_code: + assert new_markdown_code == recorded_markdown_code + else: + assert new_markdown_code == None + + test_doc( + """ +``` python +1 / 0 +``` +""", + exception=snapshot("ZeroDivisionError: division by zero"), + ) + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +text + +``` python hl_lines="1 2 3" +print(1 - 1) +``` +text +""", + blocks=snapshot( + [ + Block( + code="print(1 + 1)\n", code_header=None, block_options="", line=2 + ), + Block( + code="print(1 - 1)\n", + code_header="inline-snapshot: create test", + block_options=' hl_lines="1 2 3"', + line=7, + ), + ] + ), + ) + + def change_block(block): + block.code = "# removed" + block.code_header = "header" + block.block_options = "option a b c" + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +""", + handle_block=change_block, + blocks=snapshot( + [ + Block( + code="# removed", + code_header="header", + block_options="option a b c", + line=2, + ) + ] + ), + new_markdown_code=snapshot( + """\ +text + +``` python option a b c +# removed +``` +""" + ), + ) @pytest.mark.skipif( @@ -14,7 +236,7 @@ ) @pytest.mark.skipif( sys.version_info[:2] != (3, 12), - reason="\\r in stdout can cause problems in snapshot strings", + reason="there is no reason to test the doc with different python versions", ) @pytest.mark.parametrize( "file", @@ -36,19 +258,7 @@ def test_docs(project, file, subtests): * `outcome-passed=2` to check for the pytest test outcome """ - block_start = re.compile("( *)``` *python.*") - block_end = re.compile("```.*") - - header = re.compile("") - - text = file.read_text("utf-8") - new_lines = [] - block_lines = [] - options = set() - is_block = False - code = None - indent = "" - first_block = True + last_code = None project.pyproject( """ @@ -57,132 +267,104 @@ def test_docs(project, file, subtests): """ ) - for linenumber, line in enumerate(text.splitlines(), start=1): - m = block_start.fullmatch(line) - if m and is_block == True: - block_start_line = line - indent = m[1] - block_lines = [] - continue + extra_files = defaultdict(list) - if block_end.fullmatch(line.strip()) and is_block: - with subtests.test(line=linenumber): - is_block = False + def test_block(block: Block): + if block.code_header is None: + return block - last_code = code - code = "\n".join(block_lines) + "\n" - code = textwrap.dedent(code) - if file.suffix == ".py": - code = code.replace("\\\\", "\\") + if block.code_header.startswith("inline-snapshot-lib:"): + extra_files[block.code_header.split()[1]].append(block.code) + return block - flags = options & {"fix", "update", "create", "trim"} + if block.code_header.startswith("todo-inline-snapshot:"): + return block + assert False - args = ["--inline-snapshot", ",".join(flags)] if flags else [] + nonlocal last_code + with subtests.test(line=block.line): - if flags and "first_block" not in options: - project.setup(last_code) - else: - project.setup(code) + code = block.code - result = project.run(*args) + options = set(block.code_header.split()) - print("flags:", flags) + flags = options & {"fix", "update", "create", "trim"} - new_code = code - if flags: - new_code = project.source + args = ["--inline-snapshot", ",".join(flags)] if flags else [] - if "show_error" in options: - new_code = new_code.split("# Error:")[0] - new_code += "# Error:\n" + textwrap.indent( - result.errorLines(), "# " - ) + if flags and "first_block" not in options: + project.setup(last_code) + else: + project.setup(code) - print("new code:") - print(new_code) - print("expected code:") - print(code) + if extra_files: + all_files = [ + [(key, file) for file in files] + for key, files in extra_files.items() + ] + for files in itertools.product(*all_files): + for filename, content in files: + project.write_file(filename, content) + result = project.run(*args) - if ( - inline_snapshot._inline_snapshot._update_flags.fix - ): # pragma: no cover - flags_str = " ".join( - sorted(flags) - + sorted(options & {"first_block", "show_error"}) - + [ - f"outcome-{k}={v}" - for k, v in result.parseoutcomes().items() - if k in ("failed", "errors", "passed") - ] - ) - header_line = f"{indent}" + else: - new_lines.append(header_line) + result = project.run(*args) - from inline_snapshot._align import align - - linenum = 1 - hl_lines = "" - if last_code is not None and "first_block" not in options: - changed_lines = [] - alignment = align(last_code.split("\n"), new_code.split("\n")) - for c in alignment: - if c == "d": - continue - elif c == "m": - linenum += 1 - else: - changed_lines.append(str(linenum)) - linenum += 1 - if changed_lines: - hl_lines = f' hl_lines="{" ".join(changed_lines)}"' + print("flags:", flags, repr(block.block_options)) + + new_code = code + if flags: + new_code = project.source + + if "show_error" in options: + new_code = new_code.split("# Error:")[0] + new_code += "# Error:\n" + textwrap.indent(result.errorLines(), "# ") + + print("new code:") + print(new_code) + print("expected code:") + print(code) + + block.code_header = "inline-snapshot: " + " ".join( + sorted(flags) + + sorted(options & {"first_block", "show_error"}) + + [ + f"outcome-{k}={v}" + for k, v in result.parseoutcomes().items() + if k in ("failed", "errors", "passed") + ] + ) + + from inline_snapshot._align import align + + linenum = 1 + hl_lines = "" + if last_code is not None and "first_block" not in options: + changed_lines = [] + alignment = align(last_code.split("\n"), new_code.split("\n")) + for c in alignment: + if c == "d": + continue + elif c == "m": + linenum += 1 else: - assert False, "no lines changed" - - new_lines.append(f"{indent}``` python{hl_lines}") - - if ( - inline_snapshot._inline_snapshot._update_flags.fix - ): # pragma: no cover - new_code = new_code.rstrip("\n") - if file.suffix == ".py": - new_code = new_code.replace("\\", "\\\\") - new_code = textwrap.indent(new_code, indent) - - new_lines.append(new_code) + changed_lines.append(str(linenum)) + linenum += 1 + if changed_lines: + hl_lines = f'hl_lines="{" ".join(changed_lines)}"' else: - new_lines += block_lines + assert False, "no lines changed" + block.block_options = hl_lines - new_lines.append(line) + block.code = new_code - if not inline_snapshot._inline_snapshot._update_flags.fix: - if flags: - assert result.ret == 0 - else: - assert { - f"outcome-{k}={v}" - for k, v in result.parseoutcomes().items() - if k in ("failed", "errors", "passed") - } == {flag for flag in options if flag.startswith("outcome-")} - assert code == new_code - else: # pragma: no cover - pass - - continue - - m = header.fullmatch(line.strip()) - if m: - options = set(m.group(1).split()) - if first_block: - options.add("first_block") - first_block = False - header_line = line - is_block = True + if flags: + assert result.ret == 0 - if is_block: - block_lines.append(line) - else: - new_lines.append(line) + last_code = code + return block - if inline_snapshot._inline_snapshot._update_flags.fix: # pragma: no cover - file.write_text("\n".join(new_lines) + "\n", "utf-8") + map_code_blocks( + file, test_block, inline_snapshot._inline_snapshot._update_flags.fix + ) diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index ef7ffbde..4bd55a35 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -1,4 +1,3 @@ -import ast import contextlib import itertools import warnings @@ -8,12 +7,9 @@ from typing import Union import pytest -from hypothesis import given -from hypothesis.strategies import text from inline_snapshot import _inline_snapshot from inline_snapshot import snapshot from inline_snapshot._inline_snapshot import Flags -from inline_snapshot._utils import triple_quote from inline_snapshot.testing import Example from inline_snapshot.testing._example import snapshot_env @@ -582,194 +578,6 @@ def test_plain(check_update, executing_used): assert check_update("s = snapshot()", flags="") == snapshot("s = snapshot()") -def test_string_update(check_update): - # black --preview wraps strings to keep the line length. - # string concatenation should produce updates. - assert ( - check_update( - 'assert "ab" == snapshot("a" "b")', reported_flags="", flags="update" - ) - == 'assert "ab" == snapshot("a" "b")' - ) - - assert ( - check_update( - 'assert "ab" == snapshot("a"\n "b")', reported_flags="", flags="update" - ) - == 'assert "ab" == snapshot("a"\n "b")' - ) - - assert check_update( - 'assert "ab\\nc" == snapshot("a"\n "b\\nc")', flags="update" - ) == snapshot( - '''\ -assert "ab\\nc" == snapshot("""\\ -ab -c\\ -""")\ -''' - ) - - assert ( - check_update( - 'assert b"ab" == snapshot(b"a"\n b"b")', reported_flags="", flags="update" - ) - == 'assert b"ab" == snapshot(b"a"\n b"b")' - ) - - -def test_string_newline(check_update): - assert check_update('s = snapshot("a\\nb")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ -a -b\\ -""")\ -''' - ) - - assert check_update('s = snapshot("a\\"\\"\\"\\nb")', flags="update") == snapshot( - """\ -s = snapshot('''\\ -a\"\"\" -b\\ -''')\ -""" - ) - - assert check_update( - 's = snapshot("a\\"\\"\\"\\n\\\'\\\'\\\'b")', flags="update" - ) == snapshot( - '''\ -s = snapshot("""\\ -a\\"\\"\\" -\'\'\'b\\ -""")\ -''' - ) - - assert check_update('s = snapshot(b"a\\nb")') == snapshot('s = snapshot(b"a\\nb")') - - assert check_update('s = snapshot("\\n\\\'")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -'\\ -""")\ -''' - ) - - assert check_update('s = snapshot("\\n\\"")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -"\\ -""")\ -''' - ) - - assert check_update("s = snapshot(\"'''\\n\\\"\")", flags="update") == snapshot( - '''\ -s = snapshot("""\\ -\'\'\' -\\"\\ -""")\ -''' - ) - - assert check_update('s = snapshot("\\n\b")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -\\x08\\ -""")\ -''' - ) - - -def test_string_quote_choice(check_update): - assert check_update( - "s = snapshot(\" \\'\\'\\' \\'\\'\\' \\\"\\\"\\\"\\nother_line\")", - flags="update", - ) == snapshot( - '''\ -s = snapshot("""\\ - \'\'\' \'\'\' \\"\\"\\" -other_line\\ -""")\ -''' - ) - - assert check_update( - 's = snapshot(" \\\'\\\'\\\' \\"\\"\\" \\"\\"\\"\\nother_line")', flags="update" - ) == snapshot( - """\ -s = snapshot('''\\ - \\'\\'\\' \"\"\" \"\"\" -other_line\\ -''')\ -""" - ) - - assert check_update('s = snapshot("\\n\\"")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -"\\ -""")\ -''' - ) - - assert check_update( - "s=snapshot('\\n')", flags="update", reported_flags="" - ) == snapshot("s=snapshot('\\n')") - assert check_update( - "s=snapshot('abc\\n')", flags="update", reported_flags="" - ) == snapshot("s=snapshot('abc\\n')") - assert check_update("s=snapshot('abc\\nabc')", flags="update") == snapshot( - '''\ -s=snapshot("""\\ -abc -abc\\ -""")\ -''' - ) - assert check_update("s=snapshot('\\nabc')", flags="update") == snapshot( - '''\ -s=snapshot("""\\ - -abc\\ -""")\ -''' - ) - assert check_update("s=snapshot('a\\na\\n')", flags="update") == snapshot( - '''\ -s=snapshot("""\\ -a -a -""")\ -''' - ) - - assert ( - check_update( - '''\ -s=snapshot("""\\ -a -""")\ -''', - flags="update", - ) - == snapshot('s=snapshot("a\\n")') - ) - - -@given(s=text()) -def test_string_convert(s): - print(s) - assert ast.literal_eval(triple_quote(s)) == s - - def test_flags_repr(): assert repr(Flags({"update"})) == "Flags({'update'})" @@ -832,40 +640,6 @@ def test_type_error(check_update): assert test1 == test2 -def test_invalid_repr(check_update): - assert ( - check_update( - """\ -class Thing: - def __repr__(self): - return "+++" - - def __eq__(self,other): - if not isinstance(other,Thing): - return NotImplemented - return True - -assert Thing() == snapshot() -""", - flags="create", - ) - == snapshot( - """\ -class Thing: - def __repr__(self): - return "+++" - - def __eq__(self,other): - if not isinstance(other,Thing): - return NotImplemented - return True - -assert Thing() == snapshot(HasRepr(Thing, "+++")) -""" - ) - ) - - def test_sub_snapshot_create(check_update): assert ( @@ -1094,7 +868,7 @@ def test_dirty_equals_in_unused_snapshot() -> None: Example( """ from dirty_equals import IsStr -from inline_snapshot import snapshot +from inline_snapshot import snapshot,Is snapshot([IsStr(),3]) snapshot((IsStr(),3)) @@ -1104,7 +878,7 @@ def test_dirty_equals_in_unused_snapshot() -> None: t=(1,2) d={1:2} l=[1,2] -snapshot([t,d,l]) +snapshot([Is(t),Is(d),Is(l)]) """ ).run_inline( ["--inline-snapshot=fix"], @@ -1152,7 +926,7 @@ def test_starred_warns_list(): """ from inline_snapshot import snapshot -assert [5] == snapshot([*[4]]) +assert [5] == snapshot([*[5]]) """ ).run_inline(["--inline-snapshot=fix"]) @@ -1173,7 +947,7 @@ def test_starred_warns_dict(): """ from inline_snapshot import snapshot -assert {1:3} == snapshot({**{1:2}}) +assert {1:3} == snapshot({**{1:3}}) """ ).run_inline(["--inline-snapshot=fix"]) @@ -1195,35 +969,62 @@ class Now(DirtyEquals): def equals(self, other): return other == now - assert now == snapshot(Now()) + assert 5 == snapshot(Now()) now = 6 - assert 5 == snapshot(Now()) + assert 5 == snapshot(Now()), "different time" """ ).run_inline( ["--inline-snapshot=fix"], - changed_files=snapshot( - { - "test_something.py": """\ + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +different time\ +""" + ), + ) -from dirty_equals import DirtyEquals -from inline_snapshot import snapshot +def test_is(): -def test_time(): + Example( + """ +from inline_snapshot import snapshot,Is - now = 5 +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hi",Is(i)]) + assert ["hello",i] == snapshot({1:["hi",Is(i)]})[i] +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ - class Now(DirtyEquals): - def equals(self, other): - return other == now +from inline_snapshot import snapshot,Is - assert now == snapshot(Now()) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hi",Is(i)]) + assert ["hello",i] == snapshot({1:["hi",Is(i)], 0: ["hello", 0], 2: ["hello", 2]})[i] +""" + } + ), + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ - now = 6 +from inline_snapshot import snapshot,Is - assert 5 == snapshot(5) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hello",Is(i)]) + assert ["hello",i] == snapshot({1:["hello",Is(i)], 0: ["hello", 0], 2: ["hello", 2]})[i] """ } ), diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 14413c96..8973fd5f 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -23,7 +23,7 @@ def test_a(): assert result.report == snapshot( """\ Error: one snapshot is missing a value (--inline-snapshot=create) -You can also use --inline-snapshot=review to approve the changes interactiv +You can also use --inline-snapshot=review to approve the changes interactively """ ) @@ -70,7 +70,7 @@ def test_a(): assert result.report == snapshot( """\ Error: one snapshot has incorrect values (--inline-snapshot=fix) -You can also use --inline-snapshot=review to approve the changes interactiv +You can also use --inline-snapshot=review to approve the changes interactively """ ) @@ -157,7 +157,7 @@ def test_a(): assert result.report == snapshot( """\ Info: one snapshot can be trimmed (--inline-snapshot=trim) -You can also use --inline-snapshot=review to approve the changes interactiv +You can also use --inline-snapshot=review to approve the changes interactively """ ) @@ -205,7 +205,7 @@ def test_a(): """\ Error: one snapshot has incorrect values (--inline-snapshot=fix) Info: one snapshot can be trimmed (--inline-snapshot=trim) -You can also use --inline-snapshot=review to approve the changes interactiv +You can also use --inline-snapshot=review to approve the changes interactively """ ) @@ -495,7 +495,7 @@ def test_sub_snapshot(): your snapshot is missing one value run pytest with --inline-snapshot=create to create it ======================================================================= inline snapshot ======================================================================== Error: one snapshot is missing a value (--inline-snapshot=create) -You can also use --inline-snapshot=review to approve the changes interactiv +You can also use --inline-snapshot=review to approve the changes interactively =================================================================== short test summary info ==================================================================== ERROR test_file.py::test_sub_snapshot - Failed: your snapshot is missing one value run pytest with --inline-snapshot=create to create it ================================================================== 1 passed, 1 error in