From e13cf01b3e59b2f62ac8aa0691e1a4a068a6b6ca Mon Sep 17 00:00:00 2001 From: Harvir Sahota Date: Thu, 28 Mar 2024 21:34:39 -0700 Subject: [PATCH 1/5] Implement special sort order and binary search --- python/selfie-lib/selfie_lib/ArrayMap.py | 132 +++++++++++++---------- 1 file changed, 78 insertions(+), 54 deletions(-) diff --git a/python/selfie-lib/selfie_lib/ArrayMap.py b/python/selfie-lib/selfie_lib/ArrayMap.py index 49d317b1..ec1820d7 100644 --- a/python/selfie-lib/selfie_lib/ArrayMap.py +++ b/python/selfie-lib/selfie_lib/ArrayMap.py @@ -1,10 +1,33 @@ +from abc import ABC, abstractmethod from collections.abc import Set, Iterator, Mapping -from typing import List, TypeVar, Union -from abc import abstractmethod, ABC +from typing import List, TypeVar, Union, Any, Tuple +import bisect + + +class Comparable: + def __lt__(self, other: Any) -> bool: + return NotImplemented + + def __le__(self, other: Any) -> bool: + return NotImplemented + + def __gt__(self, other: Any) -> bool: + return NotImplemented + + def __ge__(self, other: Any) -> bool: + return NotImplemented + T = TypeVar("T") V = TypeVar("V") -K = TypeVar("K") +K = TypeVar("K", bound=Comparable) + + +def string_slash_first_comparator(a: Any, b: Any) -> int: + """Special comparator for strings where '/' is considered the lowest.""" + if isinstance(a, str) and isinstance(b, str): + return (a.replace("/", "\0"), a) < (b.replace("/", "\0"), b) + return (a < b) - (a > b) class ListBackedSet(Set[T], ABC): @@ -15,107 +38,108 @@ def __len__(self) -> int: ... def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ... def __contains__(self, item: object) -> bool: - for i in range(len(self)): - if self[i] == item: - return True - return False + try: + index = self.__binary_search(item) + except ValueError: + return False + return index >= 0 + + @abstractmethod + def __binary_search(self, item: Any) -> int: ... class ArraySet(ListBackedSet[K]): __data: List[K] - def __init__(self, data: List[K]): - raise NotImplementedError("Use ArraySet.empty() instead") + def __init__(self): + raise NotImplementedError("Use ArraySet.empty() or other class methods instead") @classmethod def __create(cls, data: List[K]) -> "ArraySet[K]": - # Create a new instance without calling __init__ instance = super().__new__(cls) instance.__data = data return instance - def __iter__(self) -> Iterator[K]: - return iter(self.__data) - @classmethod def empty(cls) -> "ArraySet[K]": if not hasattr(cls, "__EMPTY"): - cls.__EMPTY = cls([]) + cls.__EMPTY = cls.__create([]) return cls.__EMPTY def __len__(self) -> int: return len(self.__data) def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]: - if isinstance(index, int): - return self.__data[index] - elif isinstance(index, slice): - return self.__data[index] - else: - raise TypeError("Invalid argument type.") + return self.__data[index] + + def __binary_search(self, item: K) -> int: + if isinstance(item, str): + key = lambda x: x.replace("/", "\0") + return ( + bisect.bisect_left(self.__data, item, key=key) - 1 + if item in self.__data + else -1 + ) + return bisect.bisect_left(self.__data, item) - 1 if item in self.__data else -1 def plusOrThis(self, element: K) -> "ArraySet[K]": - # TODO: use binary search, and also special sort order for strings - if element in self.__data: + index = self.__binary_search(element) + if index >= 0: return self - else: - new_data = self.__data[:] - new_data.append(element) - new_data.sort() # type: ignore[reportOperatorIssue] - return ArraySet.__create(new_data) + new_data = self.__data[:] + bisect.insort_left(new_data, element) + return ArraySet.__create(new_data) class ArrayMap(Mapping[K, V]): - def __init__(self, data: list): - # TODO: hide this constructor as done in ArraySet - self.__data = data + __data: List[Tuple[K, V]] + + def __init__(self): + raise NotImplementedError("Use ArrayMap.empty() or other class methods instead") + + @classmethod + def __create(cls, data: List[Tuple[K, V]]) -> "ArrayMap[K, V]": + instance = super().__new__(cls) + instance.__data = data + return instance @classmethod def empty(cls) -> "ArrayMap[K, V]": if not hasattr(cls, "__EMPTY"): - cls.__EMPTY = cls([]) + cls.__EMPTY = cls.__create([]) return cls.__EMPTY def __getitem__(self, key: K) -> V: index = self.__binary_search_key(key) if index >= 0: - return self.__data[2 * index + 1] + return self.__data[index][1] raise KeyError(key) def __iter__(self) -> Iterator[K]: - return (self.__data[i] for i in range(0, len(self.__data), 2)) + return (key for key, _ in self.__data) def __len__(self) -> int: - return len(self.__data) // 2 + return len(self.__data) def __binary_search_key(self, key: K) -> int: - # TODO: special sort order for strings - low, high = 0, (len(self.__data) // 2) - 1 - while low <= high: - mid = (low + high) // 2 - mid_key = self.__data[2 * mid] - if mid_key < key: - low = mid + 1 - elif mid_key > key: - high = mid - 1 - else: - return mid - return -(low + 1) + keys = [k for k, _ in self.__data] + index = bisect.bisect_left(keys, key) + if index < len(keys) and keys[index] == key: + return index + return -1 def plus(self, key: K, value: V) -> "ArrayMap[K, V]": index = self.__binary_search_key(key) if index >= 0: raise ValueError("Key already exists") - insert_at = -(index + 1) new_data = self.__data[:] - new_data[insert_at * 2 : insert_at * 2] = [key, value] - return ArrayMap(new_data) + bisect.insort_left(new_data, (key, value)) + return ArrayMap.__create(new_data) def minus_sorted_indices(self, indicesToRemove: List[int]) -> "ArrayMap[K, V]": if not indicesToRemove: return self - newData = [] - for i in range(0, len(self.__data), 2): - if i // 2 not in indicesToRemove: - newData.extend(self.__data[i : i + 2]) - return ArrayMap(newData) + new_data = [ + item for i, item in enumerate(self.__data) if i not in indicesToRemove + ] + return ArrayMap.__create(new_data) From fe16c26207969ff220db272e4902252302aaee43 Mon Sep 17 00:00:00 2001 From: Harvir Sahota Date: Tue, 2 Apr 2024 13:31:33 -0700 Subject: [PATCH 2/5] Fix changes --- python/selfie-lib/selfie_lib/ArrayMap.py | 168 ++++++++++++----------- 1 file changed, 88 insertions(+), 80 deletions(-) diff --git a/python/selfie-lib/selfie_lib/ArrayMap.py b/python/selfie-lib/selfie_lib/ArrayMap.py index ec1820d7..4a0b79c2 100644 --- a/python/selfie-lib/selfie_lib/ArrayMap.py +++ b/python/selfie-lib/selfie_lib/ArrayMap.py @@ -1,33 +1,27 @@ -from abc import ABC, abstractmethod from collections.abc import Set, Iterator, Mapping -from typing import List, TypeVar, Union, Any, Tuple -import bisect - - -class Comparable: - def __lt__(self, other: Any) -> bool: - return NotImplemented - - def __le__(self, other: Any) -> bool: - return NotImplemented - - def __gt__(self, other: Any) -> bool: - return NotImplemented - - def __ge__(self, other: Any) -> bool: - return NotImplemented - +from typing import List, TypeVar, Union, Any +from abc import abstractmethod, ABC +from functools import total_ordering T = TypeVar("T") V = TypeVar("V") -K = TypeVar("K", bound=Comparable) +K = TypeVar("K") -def string_slash_first_comparator(a: Any, b: Any) -> int: - """Special comparator for strings where '/' is considered the lowest.""" - if isinstance(a, str) and isinstance(b, str): - return (a.replace("/", "\0"), a) < (b.replace("/", "\0"), b) - return (a < b) - (a > b) +@total_ordering +class Comparable: + def __init__(self, value): + self.value = value + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Comparable): + return NotImplemented + return self.value < other.value + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Comparable): + return NotImplemented + return self.value == other.value class ListBackedSet(Set[T], ABC): @@ -37,15 +31,25 @@ def __len__(self) -> int: ... @abstractmethod def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ... - def __contains__(self, item: object) -> bool: - try: - index = self.__binary_search(item) - except ValueError: - return False - return index >= 0 - - @abstractmethod - def __binary_search(self, item: Any) -> int: ... + def __contains__(self, item: Any) -> bool: + return self._binary_search(item) >= 0 + + def _binary_search(self, item: Any) -> int: + low = 0 + high = len(self) - 1 + while low <= high: + mid = (low + high) // 2 + try: + mid_val = self[mid] + if mid_val < item: + low = mid + 1 + elif mid_val > item: + high = mid - 1 + else: + return mid # item found + except TypeError: + raise ValueError(f"Cannot compare items due to a type mismatch.") + return -(low + 1) # item not found class ArraySet(ListBackedSet[K]): @@ -60,6 +64,9 @@ def __create(cls, data: List[K]) -> "ArraySet[K]": instance.__data = data return instance + def __iter__(self) -> Iterator[K]: + return iter(self.__data) + @classmethod def empty(cls) -> "ArraySet[K]": if not hasattr(cls, "__EMPTY"): @@ -72,74 +79,75 @@ def __len__(self) -> int: def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]: return self.__data[index] - def __binary_search(self, item: K) -> int: - if isinstance(item, str): - key = lambda x: x.replace("/", "\0") - return ( - bisect.bisect_left(self.__data, item, key=key) - 1 - if item in self.__data - else -1 - ) - return bisect.bisect_left(self.__data, item) - 1 if item in self.__data else -1 - def plusOrThis(self, element: K) -> "ArraySet[K]": - index = self.__binary_search(element) - if index >= 0: + if element in self: return self - new_data = self.__data[:] - bisect.insort_left(new_data, element) - return ArraySet.__create(new_data) + else: + new_data = self.__data[:] + new_data.append(element) + new_data.sort(key=Comparable) + return ArraySet.__create(new_data) class ArrayMap(Mapping[K, V]): - __data: List[Tuple[K, V]] - - def __init__(self): - raise NotImplementedError("Use ArrayMap.empty() or other class methods instead") - - @classmethod - def __create(cls, data: List[Tuple[K, V]]) -> "ArrayMap[K, V]": - instance = super().__new__(cls) - instance.__data = data - return instance + def __init__(self, data=None): + if data is None: + self.__data = [] + else: + self.__data = data @classmethod def empty(cls) -> "ArrayMap[K, V]": if not hasattr(cls, "__EMPTY"): - cls.__EMPTY = cls.__create([]) + cls.__EMPTY = cls([]) return cls.__EMPTY def __getitem__(self, key: K) -> V: - index = self.__binary_search_key(key) + index = self._binary_search_key(key) if index >= 0: - return self.__data[index][1] + return self.__data[2 * index + 1] raise KeyError(key) def __iter__(self) -> Iterator[K]: - return (key for key, _ in self.__data) + return (self.__data[i] for i in range(0, len(self.__data), 2)) def __len__(self) -> int: - return len(self.__data) - - def __binary_search_key(self, key: K) -> int: - keys = [k for k, _ in self.__data] - index = bisect.bisect_left(keys, key) - if index < len(keys) and keys[index] == key: - return index - return -1 + return len(self.__data) // 2 + + def _binary_search_key(self, key: K) -> int: + def compare(a, b): + """Comparator that puts '/' first in strings.""" + if isinstance(a, str) and isinstance(b, str): + a, b = a.replace("/", "\0"), b.replace("/", "\0") + return (a > b) - (a < b) + + low, high = 0, len(self.__data) // 2 - 1 + while low <= high: + mid = (low + high) // 2 + mid_key = self.__data[2 * mid] + comparison = compare(mid_key, key) + if comparison < 0: + low = mid + 1 + elif comparison > 0: + high = mid - 1 + else: + return mid # key found + return -(low + 1) # key not found def plus(self, key: K, value: V) -> "ArrayMap[K, V]": - index = self.__binary_search_key(key) + index = self._binary_search_key(key) if index >= 0: raise ValueError("Key already exists") + insert_at = -(index + 1) new_data = self.__data[:] - bisect.insort_left(new_data, (key, value)) - return ArrayMap.__create(new_data) + new_data.insert(insert_at * 2, key) + new_data.insert(insert_at * 2 + 1, value) + return ArrayMap(new_data) - def minus_sorted_indices(self, indicesToRemove: List[int]) -> "ArrayMap[K, V]": - if not indicesToRemove: - return self - new_data = [ - item for i, item in enumerate(self.__data) if i not in indicesToRemove - ] - return ArrayMap.__create(new_data) + def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]": + new_data = self.__data[:] + adjusted_indices = [i * 2 for i in indices] + [i * 2 + 1 for i in indices] + adjusted_indices.sort() + for index in reversed(adjusted_indices): + del new_data[index] + return ArrayMap(new_data) From 1845272b64de1e19c5013aba46721e21101af1b1 Mon Sep 17 00:00:00 2001 From: Harvir Sahota Date: Tue, 2 Apr 2024 19:53:19 -0700 Subject: [PATCH 3/5] Centralize binary search --- python/selfie-lib/selfie_lib/ArrayMap.py | 117 +++++++++++------------ 1 file changed, 55 insertions(+), 62 deletions(-) diff --git a/python/selfie-lib/selfie_lib/ArrayMap.py b/python/selfie-lib/selfie_lib/ArrayMap.py index 4a0b79c2..beb81e0a 100644 --- a/python/selfie-lib/selfie_lib/ArrayMap.py +++ b/python/selfie-lib/selfie_lib/ArrayMap.py @@ -1,27 +1,41 @@ from collections.abc import Set, Iterator, Mapping -from typing import List, TypeVar, Union, Any +from typing import List, TypeVar, Union, Any, Callable, Optional, Generator from abc import abstractmethod, ABC -from functools import total_ordering T = TypeVar("T") V = TypeVar("V") K = TypeVar("K") -@total_ordering -class Comparable: - def __init__(self, value): - self.value = value +class BinarySearchUtil: + @staticmethod + def binary_search( + data, item, compare_func: Optional[Callable[[Any, Any], int]] = None + ) -> int: + low, high = 0, len(data) - 1 + while low <= high: + mid = (low + high) // 2 + mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid] + comparison = ( + compare_func(mid_val, item) + if compare_func + else (mid_val > item) - (mid_val < item) + ) - def __lt__(self, other: Any) -> bool: - if not isinstance(other, Comparable): - return NotImplemented - return self.value < other.value + if comparison < 0: + low = mid + 1 + elif comparison > 0: + high = mid - 1 + else: + return mid # item found + return -(low + 1) # item not found - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Comparable): - return NotImplemented - return self.value == other.value + @staticmethod + def default_compare(a: Any, b: Any) -> int: + """Default comparison function for binary search, with special handling for strings.""" + if isinstance(a, str) and isinstance(b, str): + a, b = a.replace("/", "\0"), b.replace("/", "\0") + return (a > b) - (a < b) class ListBackedSet(Set[T], ABC): @@ -31,25 +45,14 @@ def __len__(self) -> int: ... @abstractmethod def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ... + @abstractmethod + def __iter__(self) -> Iterator[T]: ... + def __contains__(self, item: Any) -> bool: return self._binary_search(item) >= 0 def _binary_search(self, item: Any) -> int: - low = 0 - high = len(self) - 1 - while low <= high: - mid = (low + high) // 2 - try: - mid_val = self[mid] - if mid_val < item: - low = mid + 1 - elif mid_val > item: - high = mid - 1 - else: - return mid # item found - except TypeError: - raise ValueError(f"Cannot compare items due to a type mismatch.") - return -(low + 1) # item not found + return BinarySearchUtil.binary_search(self, item) class ArraySet(ListBackedSet[K]): @@ -80,59 +83,49 @@ def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]: return self.__data[index] def plusOrThis(self, element: K) -> "ArraySet[K]": - if element in self: + index = self._binary_search(element) + if index >= 0: return self else: + insert_at = -(index + 1) new_data = self.__data[:] - new_data.append(element) - new_data.sort(key=Comparable) + new_data.insert(insert_at, element) return ArraySet.__create(new_data) class ArrayMap(Mapping[K, V]): - def __init__(self, data=None): - if data is None: - self.__data = [] - else: - self.__data = data + __data: List[Union[K, V]] + + def __init__(self): + raise NotImplementedError("Use ArrayMap.empty() or other class methods instead") + + @classmethod + def __create(cls, data: List[Union[K, V]]) -> "ArrayMap[K, V]": + instance = cls.__new__(cls) + instance.__data = data + return instance @classmethod def empty(cls) -> "ArrayMap[K, V]": if not hasattr(cls, "__EMPTY"): - cls.__EMPTY = cls([]) + cls.__EMPTY = cls.__create([]) return cls.__EMPTY def __getitem__(self, key: K) -> V: index = self._binary_search_key(key) if index >= 0: - return self.__data[2 * index + 1] + return self.__data[2 * index + 1] # type: ignore raise KeyError(key) def __iter__(self) -> Iterator[K]: - return (self.__data[i] for i in range(0, len(self.__data), 2)) + return (self.__data[i] for i in range(0, len(self.__data), 2)) # type: ignore def __len__(self) -> int: return len(self.__data) // 2 def _binary_search_key(self, key: K) -> int: - def compare(a, b): - """Comparator that puts '/' first in strings.""" - if isinstance(a, str) and isinstance(b, str): - a, b = a.replace("/", "\0"), b.replace("/", "\0") - return (a > b) - (a < b) - - low, high = 0, len(self.__data) // 2 - 1 - while low <= high: - mid = (low + high) // 2 - mid_key = self.__data[2 * mid] - comparison = compare(mid_key, key) - if comparison < 0: - low = mid + 1 - elif comparison > 0: - high = mid - 1 - else: - return mid # key found - return -(low + 1) # key not found + keys = [self.__data[i] for i in range(0, len(self.__data), 2)] + return BinarySearchUtil.binary_search(keys, key) def plus(self, key: K, value: V) -> "ArrayMap[K, V]": index = self._binary_search_key(key) @@ -142,12 +135,12 @@ def plus(self, key: K, value: V) -> "ArrayMap[K, V]": new_data = self.__data[:] new_data.insert(insert_at * 2, key) new_data.insert(insert_at * 2 + 1, value) - return ArrayMap(new_data) + return ArrayMap.__create(new_data) def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]": new_data = self.__data[:] adjusted_indices = [i * 2 for i in indices] + [i * 2 + 1 for i in indices] - adjusted_indices.sort() - for index in reversed(adjusted_indices): + adjusted_indices.sort(reverse=True) + for index in adjusted_indices: del new_data[index] - return ArrayMap(new_data) + return ArrayMap.__create(new_data) From d6b542f60b7328e619af3d7882ed51ccbedf0599 Mon Sep 17 00:00:00 2001 From: Harvir Sahota Date: Wed, 3 Apr 2024 10:20:45 -0700 Subject: [PATCH 4/5] Update binary search --- python/selfie-lib/selfie_lib/ArrayMap.py | 61 +++++++++++------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/python/selfie-lib/selfie_lib/ArrayMap.py b/python/selfie-lib/selfie_lib/ArrayMap.py index beb81e0a..5cd9fd57 100644 --- a/python/selfie-lib/selfie_lib/ArrayMap.py +++ b/python/selfie-lib/selfie_lib/ArrayMap.py @@ -1,5 +1,5 @@ from collections.abc import Set, Iterator, Mapping -from typing import List, TypeVar, Union, Any, Callable, Optional, Generator +from typing import List, TypeVar, Union, Any from abc import abstractmethod, ABC T = TypeVar("T") @@ -7,35 +7,32 @@ K = TypeVar("K") -class BinarySearchUtil: - @staticmethod - def binary_search( - data, item, compare_func: Optional[Callable[[Any, Any], int]] = None - ) -> int: - low, high = 0, len(data) - 1 - while low <= high: - mid = (low + high) // 2 - mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid] - comparison = ( - compare_func(mid_val, item) - if compare_func - else (mid_val > item) - (mid_val < item) - ) - - if comparison < 0: - low = mid + 1 - elif comparison > 0: - high = mid - 1 - else: - return mid # item found - return -(low + 1) # item not found - - @staticmethod - def default_compare(a: Any, b: Any) -> int: - """Default comparison function for binary search, with special handling for strings.""" - if isinstance(a, str) and isinstance(b, str): - a, b = a.replace("/", "\0"), b.replace("/", "\0") - return (a > b) - (a < b) +def _compare_normal(a, b) -> int: + if a == b: + return 0 + elif a < b: + return -1 + else: + return 1 + +def _compare_string_slash_first(a: str, b: str) -> int: + return _compare_normal(a.replace("/", "\0"), b.replace("/", "\0")) + +def _binary_search(data, item) -> int: + compare_func = _compare_string_slash_first if isinstance(item, str) else _compare_normal + low, high = 0, len(data) - 1 + while low <= high: + mid = (low + high) // 2 + mid_val = data[mid] + comparison = compare_func(mid_val, item) + + if comparison < 0: + low = mid + 1 + elif comparison > 0: + high = mid - 1 + else: + return mid # item found + return -(low + 1) # item not found class ListBackedSet(Set[T], ABC): @@ -52,7 +49,7 @@ def __contains__(self, item: Any) -> bool: return self._binary_search(item) >= 0 def _binary_search(self, item: Any) -> int: - return BinarySearchUtil.binary_search(self, item) + return _binary_search(self, item) class ArraySet(ListBackedSet[K]): @@ -125,7 +122,7 @@ def __len__(self) -> int: def _binary_search_key(self, key: K) -> int: keys = [self.__data[i] for i in range(0, len(self.__data), 2)] - return BinarySearchUtil.binary_search(keys, key) + return _binary_search(keys, key) def plus(self, key: K, value: V) -> "ArrayMap[K, V]": index = self._binary_search_key(key) From 758cf81488b02e8750b3e3f803f230cb2041bede Mon Sep 17 00:00:00 2001 From: Harvir Sahota Date: Wed, 3 Apr 2024 10:23:06 -0700 Subject: [PATCH 5/5] ruff format --- python/selfie-lib/selfie_lib/ArrayMap.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/selfie-lib/selfie_lib/ArrayMap.py b/python/selfie-lib/selfie_lib/ArrayMap.py index 5cd9fd57..0f307f43 100644 --- a/python/selfie-lib/selfie_lib/ArrayMap.py +++ b/python/selfie-lib/selfie_lib/ArrayMap.py @@ -15,11 +15,15 @@ def _compare_normal(a, b) -> int: else: return 1 + def _compare_string_slash_first(a: str, b: str) -> int: return _compare_normal(a.replace("/", "\0"), b.replace("/", "\0")) + def _binary_search(data, item) -> int: - compare_func = _compare_string_slash_first if isinstance(item, str) else _compare_normal + compare_func = ( + _compare_string_slash_first if isinstance(item, str) else _compare_normal + ) low, high = 0, len(data) - 1 while low <= high: mid = (low + high) // 2