Skip to content

Commit

Permalink
Implement type-aware get for TypedDict
Browse files Browse the repository at this point in the history
Previously, `get` would simply fallback to the type of the underlying dictionary which made
TypedDicts hard to use with code that's parsing objects where fields may or may not be
present (for example, parsing a response).

This implementation _explicitly_ ignores the default parameter's type as it's quite useful to
chain together get calls (Until something like PEP 505 hits 😄)

```python
foo.get('a', {}).get('b', {}).get('c')
```

This fixes python#2612
  • Loading branch information
Roy Williams committed Jan 12, 2017
1 parent 4ca1709 commit 4c3cb1b
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 8 deletions.
10 changes: 8 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
get_typ_args, set_typ_args,
)
TypedDictGetFunction)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
Expand Down Expand Up @@ -341,6 +341,10 @@ def check_call(self, callee: Type, args: List[Expression],
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if isinstance(callee, TypedDictGetFunction):
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
return return_type, callee
if callee.is_concrete_type_obj() and callee.type_object().is_abstract:
type = callee.type_object()
self.msg.cannot_instantiate_abstract_class(
Expand Down Expand Up @@ -1484,11 +1488,13 @@ def _get_value(self, index: Expression) -> Optional[int]:
return None

def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
return self.get_typeddict_index_type(td_type, index)

def get_typeddict_index_type(self, td_type: TypedDictType, index: Expression) -> Type:
if not isinstance(index, (StrExpr, UnicodeExpr)):
self.msg.typeddict_item_name_must_be_string_literal(td_type, index)
return AnyType()
item_name = index.value

item_type = td_type.items.get(item_name)
if item_type is None:
self.msg.typeddict_item_name_not_found(td_type, item_name, index)
Expand Down
13 changes: 8 additions & 5 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from mypy.types import (
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
Overloaded, TypeVarType, UnionType, PartialType,
DeletedType, NoneTyp, TypeType, function_type
)
DeletedType, NoneTyp, TypeType, function_type,
TypedDictGetFunction)
from mypy.nodes import (
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
ARG_POS, ARG_STAR, ARG_STAR2,
Expand Down Expand Up @@ -120,9 +120,12 @@ def analyze_member_access(name: str,
original_type=original_type, chk=chk)
elif isinstance(typ, TypedDictType):
# Actually look up from the fallback instance type.
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
is_operator, builtin_type, not_ready_callback, msg,
original_type=original_type, chk=chk)
result = analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
is_operator, builtin_type, not_ready_callback, msg,
original_type=original_type, chk=chk)
if name == 'get' and isinstance(result, CallableType):
result = TypedDictGetFunction(typ, result)
return result
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
# Class attribute.
# TODO super?
Expand Down
20 changes: 20 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,26 @@ def zipall(self, right: 'TypedDictType') \
yield (item_name, None, right_item_type)


class TypedDictGetFunction(CallableType):
"""A special callable type containing a reference to the TypedDict `get` callable instance.
This is needed to delay determining the signature of a TypedDict's `get` method until the
method is actually called. This allows `get` to behave just as indexing into the TypedDict
would.
This is not a real type, but is needed to allow TypedDict.get to behave as expected.
"""
def __init__(self, typed_dict: TypedDictType, fallback_callable: CallableType) -> None:
super().__init__(fallback_callable.arg_types, fallback_callable.arg_kinds,
fallback_callable.arg_names, fallback_callable.ret_type,
fallback_callable.fallback, fallback_callable.name,
fallback_callable.definition, fallback_callable.variables,
fallback_callable.line, fallback_callable.column,
fallback_callable.is_ellipsis_args, fallback_callable.implicit,
fallback_callable.is_classmethod_class, fallback_callable.special_sig)
self.typed_dict = typed_dict
self.fallback_callable = fallback_callable


class StarType(Type):
"""The star type *type_parameter.
Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,38 @@ def set_coordinate(p: TaggedPoint, key: str, value: int) -> None:

-- Special Method: get

[case testCanUseGetMethodWithStringLiteralKey]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
reveal_type(p.get('type')) # E: Revealed type is 'builtins.str'
reveal_type(p.get('x')) # E: Revealed type is 'builtins.int'
reveal_type(p.get('y')) # E: Revealed type is 'builtins.int'
[builtins fixtures/dict.pyi]

[case testCannotGetMethodWithInvalidStringLiteralKey]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
p.get('z') # E: 'z' is not a valid item name; expected one of ['type', 'x', 'y']
[builtins fixtures/dict.pyi]

[case testGetMethodWithVariableKeyFallsBack]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
key = 'type'
reveal_type(p.get(key)) # E: Revealed type is 'builtins.object*'
[builtins fixtures/dict.pyi]

[case testChainedGetMethodWithFallback]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
reveal_type(p.get('first_point', {}).get('x')) # E: Revealed type is 'builtins.int'
[builtins fixtures/dict.pyi]

-- TODO: Implement support for these cases:
--[case testGetOfTypedDictWithValidStringLiteralKeyReturnsPreciseType]
--[case testGetOfTypedDictWithInvalidStringLiteralKeyIsError]
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]):
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def get(self, k: KT, default: VT=None) -> VT: pass
def update(self, a: Mapping[KT, VT]) -> None: pass

class int: # for convenience
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/lib-stub/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ class Sequence(Iterable[T], Generic[T]):
@abstractmethod
def __getitem__(self, n: Any) -> T: pass

class Mapping(Generic[T, U]): pass
class Mapping(Generic[T, U]):
@abstractmethod
def get(self, k: T, default: U=None) -> U: pass

class MutableMapping(Generic[T, U]): pass

Expand Down

0 comments on commit 4c3cb1b

Please sign in to comment.