Skip to content

Commit

Permalink
Remove duplicates in _AndList/_OrList filters.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanslenders committed Feb 28, 2023
1 parent 0b4f04e commit 56abcd9
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 14 deletions.
76 changes: 63 additions & 13 deletions src/prompt_toolkit/filters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class Filter(metaclass=ABCMeta):
"""

def __init__(self) -> None:
self._and_cache: dict[Filter, _AndList] = {}
self._or_cache: dict[Filter, _OrList] = {}
self._and_cache: dict[Filter, Filter] = {}
self._or_cache: dict[Filter, Filter] = {}
self._invert_result: Filter | None = None

@abstractmethod
Expand All @@ -40,7 +40,7 @@ def __and__(self, other: Filter) -> Filter:
if other in self._and_cache:
return self._and_cache[other]

result = _AndList([self, other])
result = _AndList.create([self, other])
self._and_cache[other] = result
return result

Expand All @@ -58,7 +58,7 @@ def __or__(self, other: Filter) -> Filter:
if other in self._or_cache:
return self._or_cache[other]

result = _OrList([self, other])
result = _OrList.create([self, other])
self._or_cache[other] = result
return result

Expand Down Expand Up @@ -86,20 +86,49 @@ def __bool__(self) -> None:
)


def _remove_duplicates(filters: list[Filter]) -> list[Filter]:
result = []
for f in filters:
if f not in result:
result.append(f)
return result


class _AndList(Filter):
"""
Result of &-operation between several filters.
"""

def __init__(self, filters: Iterable[Filter]) -> None:
def __init__(self, filters: list[Filter]) -> None:
super().__init__()
self.filters: list[Filter] = []
self.filters = filters

@classmethod
def create(cls, filters: Iterable[Filter]) -> Filter:
"""
Create a new filter by applying an `&` operator between them.
If there's only one unique filter in the given iterable, it will return
that one filter instead of an `_AndList`.
"""
filters_2: list[Filter] = []

for f in filters:
if isinstance(f, _AndList): # Turn nested _AndLists into one.
self.filters.extend(f.filters)
filters_2.extend(f.filters)
else:
self.filters.append(f)
filters_2.append(f)

# Remove duplicates. This could speed up execution, and doesn't make a
# difference for the evaluation.
filters = _remove_duplicates(filters_2)

# If only one filter is left, return that without wrapping into an
# `_AndList`.
if len(filters) == 1:
return filters[0]

return cls(filters)

def __call__(self) -> bool:
return all(f() for f in self.filters)
Expand All @@ -113,15 +142,36 @@ class _OrList(Filter):
Result of |-operation between several filters.
"""

def __init__(self, filters: Iterable[Filter]) -> None:
def __init__(self, filters: list[Filter]) -> None:
super().__init__()
self.filters: list[Filter] = []
self.filters = filters

@classmethod
def create(cls, filters: Iterable[Filter]) -> Filter:
"""
Create a new filter by applying an `|` operator between them.
If there's only one unique filter in the given iterable, it will return
that one filter instead of an `_OrList`.
"""
filters_2: list[Filter] = []

for f in filters:
if isinstance(f, _OrList): # Turn nested _OrLists into one.
self.filters.extend(f.filters)
if isinstance(f, _OrList): # Turn nested _AndLists into one.
filters_2.extend(f.filters)
else:
self.filters.append(f)
filters_2.append(f)

# Remove duplicates. This could speed up execution, and doesn't make a
# difference for the evaluation.
filters = _remove_duplicates(filters_2)

# If only one filter is left, return that without wrapping into an
# `_AndList`.
if len(filters) == 1:
return filters[0]

return cls(filters)

def __call__(self) -> bool:
return any(f() for f in self.filters)
Expand Down
50 changes: 49 additions & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import pytest
import gc

import pytest

from prompt_toolkit.filters import Always, Condition, Filter, Never, to_filter
from prompt_toolkit.filters.base import _AndList, _OrList


def test_never():
Expand Down Expand Up @@ -44,6 +46,32 @@ def test_and():
assert c3() == (a and b)


def test_nested_and():
for a in (True, False):
for b in (True, False):
for c in (True, False):
c1 = Condition(lambda: a)
c2 = Condition(lambda: b)
c3 = Condition(lambda: c)
c4 = (c1 & c2) & c3

assert isinstance(c4, Filter)
assert c4() == (a and b and c)


def test_nested_or():
for a in (True, False):
for b in (True, False):
for c in (True, False):
c1 = Condition(lambda: a)
c2 = Condition(lambda: b)
c3 = Condition(lambda: c)
c4 = (c1 | c2) | c3

assert isinstance(c4, Filter)
assert c4() == (a or b or c)


def test_to_filter():
f1 = to_filter(True)
f2 = to_filter(False)
Expand Down Expand Up @@ -75,6 +103,7 @@ def test_filter_cache_regression_1():
y = (cond & cond) & cond
assert x == y


def test_filter_cache_regression_2():
cond1 = Condition(lambda: True)
cond2 = Condition(lambda: True)
Expand All @@ -83,3 +112,22 @@ def test_filter_cache_regression_2():
x = (cond1 & cond2) & cond3
y = (cond1 & cond2) & cond3
assert x == y


def test_filter_remove_duplicates():
cond1 = Condition(lambda: True)
cond2 = Condition(lambda: True)

# When a condition is appended to itself using an `&` or `|` operator, it
# should not be present twice. Having it twice in the `_AndList` or
# `_OrList` will make them more expensive to evaluate.

assert isinstance(cond1 & cond1, Condition)
assert isinstance(cond1 & cond1 & cond1, Condition)
assert isinstance(cond1 & cond1 & cond2, _AndList)
assert len((cond1 & cond1 & cond2).filters) == 2

assert isinstance(cond1 | cond1, Condition)
assert isinstance(cond1 | cond1 | cond1, Condition)
assert isinstance(cond1 | cond1 | cond2, _OrList)
assert len((cond1 | cond1 | cond2).filters) == 2

0 comments on commit 56abcd9

Please sign in to comment.