diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 32fa391bb0e2..7383a2b69610 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3145,7 +3145,7 @@ def visit_cast_expr(self, expr: CastExpr) -> Type: return target_type def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type: - source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form), + source_type = self.accept(expr.expr, type_context=self.type_context[-1], allow_none_return=True, always_allow_any=True) target_type = expr.type if not is_same_type(source_type, target_type): diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 84b6105170bd..fd10b82cc558 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -1049,6 +1049,20 @@ assert_type(a, Any) # E: Expression is of type "int", not "Any" assert_type(a, Literal[1]) # E: Expression is of type "int", not "Literal[1]" [builtins fixtures/tuple.pyi] +[case testAssertTypeGeneric] +from typing import assert_type, TypeVar, Generic +from typing_extensions import Literal +T = TypeVar("T") +def f(x: T) -> T: return x +assert_type(f(1), int) +class Gen(Generic[T]): + def __new__(cls, obj: T) -> Gen[T]: ... +assert_type(Gen(1), Gen[int]) +# With type context, it infers Gen[Literal[1]] instead. +y: Gen[Literal[1]] = assert_type(Gen(1), Gen[Literal[1]]) + +[builtins fixtures/tuple.pyi] + -- None return type -- ----------------