Skip to content

Commit

Permalink
safe generic thanks to python/typing#498
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrmaslanka committed Mar 10, 2020
1 parent a52ed9d commit 813ee95
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
15 changes: 15 additions & 0 deletions satella/coding/_safe_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
This module exports a safe generic type and a type variable T to get around the bug in Python 3.5
https://github.com/python/typing/issues/498
"""
import typing as tp
import sys

__all__ = ['T', 'SafeGeneric']

T = tp.TypeVar('T')

if sys.version_info[1] == 5:
SafeGeneric = object
else:
SafeGeneric = tp.Generic[T]
4 changes: 2 additions & 2 deletions satella/coding/structures/dict_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from satella.configuration.schema import Descriptor, descriptor_from_dict
from satella.exceptions import ConfigurationValidationError
from .._safe_typing import SafeGeneric, T

__all__ = ['DictObject', 'apply_dict_object']

T = tp.TypeVar('T')
U = tp.TypeVar('U')


class DictObject(dict, tp.Generic[T]):
class DictObject(dict, SafeGeneric):
"""
A dictionary wrapper that can be accessed by attributes. Eg:
Expand Down
50 changes: 24 additions & 26 deletions satella/coding/structures/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from abc import ABCMeta, abstractmethod

from ..decorators import wraps
from .._safe_typing import SafeGeneric, T

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,10 +80,7 @@ def inner(self, a, *args, **kwargs):
return inner


HeapVar = tp.TypeVar('T')


class Heap(collections.UserList, tp.Generic[HeapVar]):
class Heap(collections.UserList, SafeGeneric):
"""
Sane heap as object - not like heapq.
Expand All @@ -92,16 +90,16 @@ class Heap(collections.UserList, tp.Generic[HeapVar]):
Not thread-safe
"""

def __init__(self, from_list: tp.Optional[tp.Iterable[HeapVar]] = None):
def __init__(self, from_list: tp.Optional[tp.Iterable[T]] = None):
super().__init__(from_list)
heapq.heapify(self.data)

def push_many(self, items: tp.Iterable[HeapVar]) -> None:
def push_many(self, items: tp.Iterable[T]) -> None:
for item in items:
self.push(item)

@_extras_to_one
def push(self, item: HeapVar) -> None:
def push(self, item: T) -> None:
"""
Use it like:
Expand All @@ -124,19 +122,19 @@ def __deepcopy__(self, memo) -> 'Heap':
def __copy__(self) -> 'Heap':
return self.__copy(copy.copy)

def __iter__(self) -> tp.Iterator[HeapVar]:
def __iter__(self) -> tp.Iterator[T]:
return self.data.__iter__()

def pop(self) -> HeapVar:
def pop(self) -> T:
"""
Return smallest element of the heap.
:raises IndexError: on empty heap
"""
return heapq.heappop(self.data)

def filter_map(self, filter_fun: tp.Optional[tp.Callable[[HeapVar], bool]] = None,
map_fun: tp.Optional[tp.Callable[[HeapVar], tp.Any]] = None):
def filter_map(self, filter_fun: tp.Optional[tp.Callable[[T], bool]] = None,
map_fun: tp.Optional[tp.Callable[[T], tp.Any]] = None):
"""
Get only items that return True when condition(item) is True. Apply a
transform: item' = item(condition) on
Expand All @@ -154,7 +152,7 @@ def __bool__(self) -> bool:
"""
return len(self.data) > 0

def iter_ascending(self) -> tp.Iterable[HeapVar]:
def iter_ascending(self) -> tp.Iterable[T]:
"""
Return an iterator returning all elements in this heap sorted ascending.
State of the heap is not changed
Expand All @@ -163,7 +161,7 @@ def iter_ascending(self) -> tp.Iterable[HeapVar]:
while heap:
yield heapq.heappop(heap)

def iter_descending(self) -> tp.Iterable[HeapVar]:
def iter_descending(self) -> tp.Iterable[T]:
"""
Return an iterator returning all elements in this heap sorted descending.
State of the heap is not changed.
Expand All @@ -184,11 +182,11 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return u'<satella.coding.Heap>'

def __contains__(self, item: HeapVar) -> bool:
def __contains__(self, item: T) -> bool:
return item in self.data


class SetHeap(Heap, tp.Generic[HeapVar]):
class SetHeap(Heap, SafeGeneric):
"""
A heap with additional invariant that no two elements are the same.
Expand All @@ -197,25 +195,25 @@ class SetHeap(Heap, tp.Generic[HeapVar]):
#notthreadsafe
"""

def __init__(self, from_list: tp.Optional[tp.Iterable[HeapVar]] = None):
def __init__(self, from_list: tp.Optional[tp.Iterable[T]] = None):
super().__init__(from_list=from_list)
self.set = set(self.data)

def push(self, item: HeapVar):
def push(self, item: T):
if item not in self.set:
super().push(item)
self.set.add(item)

def pop(self) -> HeapVar:
def pop(self) -> T:
item = super().pop()
self.set.remove(item)
return item

def __contains__(self, item: HeapVar) -> bool:
def __contains__(self, item: T) -> bool:
return item in self.set

def filter_map(self, filter_fun: tp.Optional[tp.Callable[[HeapVar], bool]] = None,
map_fun: tp.Optional[tp.Callable[[HeapVar], tp.Any]] = None):
def filter_map(self, filter_fun: tp.Optional[tp.Callable[[T], bool]] = None,
map_fun: tp.Optional[tp.Callable[[T], tp.Any]] = None):
super().filter_map(filter_fun=filter_fun, map_fun=map_fun)
self.set = set(self.data)

Expand Down Expand Up @@ -245,7 +243,7 @@ class TimeBasedHeap(Heap):
def __repr__(self):
return '<satella.coding.TimeBasedHeap with %s elements>' % (len(self.data),)

def items(self) -> tp.Iterable[HeapVar]:
def items(self) -> tp.Iterable[T]:
"""
Return an iterable, but WITHOUT timestamps (only items), in
unspecified order
Expand All @@ -259,8 +257,8 @@ def __init__(self, default_clock_source: tp.Callable[[], int] = None):
self.default_clock_source = default_clock_source or time.monotonic
super().__init__(from_list=())

def put(self, timestamp_or_value: tp.Union[tp.Tuple[tp.Union[float, int], HeapVar]],
value: tp.Optional[HeapVar] = None) -> None:
def put(self, timestamp_or_value: tp.Union[tp.Tuple[tp.Union[float, int], T]],
value: tp.Optional[T] = None) -> None:
"""
Put an item on heap.
Expand All @@ -275,7 +273,7 @@ def put(self, timestamp_or_value: tp.Union[tp.Tuple[tp.Union[float, int], HeapVa
self.push((timestamp, item))

def pop_less_than(self, less: tp.Optional[tp.Union[int, float]] = None) -> tp.Generator[
HeapVar, None, None]:
T, None, None]:
"""
Return all elements less (sharp inequality) than particular value.
Expand All @@ -295,7 +293,7 @@ def pop_less_than(self, less: tp.Optional[tp.Union[int, float]] = None) -> tp.Ge
return
yield self.pop()

def remove(self, item: HeapVar) -> None:
def remove(self, item: T) -> None:
"""
Remove all things equal to item
"""
Expand Down

1 comment on commit 813ee95

@hynekcer
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably only an incompatible hack that doesn't work as a correct Generic type e.g with mypy 0.770 and Python 3.7
The error message in mypy is "MyClass" expects no type arguments, but 1 given if I use MyClass[T]

or 'Variable SafeGeneric" is not valid as a type if the variable SafeGeneric is moved to the same module

or Incompatible types in assignment (expression has type "MyClass[<nothing>]", variable has type "MyClass[T]")

A solution that works:

class A(tp.Generic[T]):
    __copy__ = None   # a fix for issue in Python 3.5 https://github.com/python/typing/issues/498
    ...

This is safe:

  1. The value __copy__ = None has the same effect for the function copy.copy() like an undefined __copy__ attribute. A possible problem is if it would overshadow the right __copy__ attribute.
  2. The "Generic" type is the last ancestor of MyClass. (The only possible problem would be if it is used as generic mixin for a class with a __copy__ method, but it can be easily solved by a super(...).__copy__())

Please sign in to comment.