Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
if isinstance(actual, Instance):
instance = actual
erased = erase_typevars(template)
assert isinstance(erased, Instance) # type: ignore[misc]
assert isinstance(erased, ProperType) and isinstance(erased, Instance)
# We always try nominal inference if possible,
# it is much faster than the structural one.
if self.direction == SUBTYPE_OF and template.type.has_base(instance.type.fullname):
Expand Down Expand Up @@ -996,7 +996,24 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
res.extend(cb)
return res
elif isinstance(actual, TupleType) and self.direction == SUPERTYPE_OF:
return infer_constraints(template, mypy.typeops.tuple_fallback(actual), self.direction)
instance = mypy.typeops.tuple_fallback(actual)
erased = erase_typevars(template)
assert isinstance(erased, ProperType) and isinstance(erased, Instance)
# Special-case protocols before using fallback to get more precise constraints
# for custom tuple types like NamedTuples.
if (
template.type.is_protocol
and self.direction == SUPERTYPE_OF
and not any(template == t for t in reversed(template.type.inferring))
and mypy.subtypes.is_protocol_implementation(instance, erased, skip=["__call__"])
):
template.type.inferring.append(template)
res = self.infer_constraints_from_protocol_members(
instance, template, original_actual, template
)
template.type.inferring.pop()
return res
return infer_constraints(template, instance, self.direction)
elif isinstance(actual, TypeVarType):
if not actual.values and not actual.id.is_meta_var():
return infer_constraints(template, actual.upper_bound, self.direction)
Expand Down
23 changes: 17 additions & 6 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2220,8 +2220,9 @@ def report_protocol_problems(
class_obj = False
is_module = False
skip = []
original_subtype = subtype
if isinstance(subtype, TupleType):
subtype = subtype.partial_fallback
subtype = mypy.typeops.tuple_fallback(subtype)
elif isinstance(subtype, TypedDictType):
subtype = subtype.fallback
elif isinstance(subtype, TypeType):
Expand All @@ -2233,7 +2234,7 @@ def report_protocol_problems(
if subtype.is_type_obj():
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
ret_type = mypy.typeops.tuple_fallback(ret_type)
if not isinstance(ret_type, Instance):
return
class_obj = True
Expand All @@ -2243,6 +2244,10 @@ def report_protocol_problems(
skip = ["__call__"]
if subtype.extra_attrs and subtype.extra_attrs.mod_name:
is_module = True
if not isinstance(original_subtype, TupleType):
# Apart from instances, only tuples are supported by
# is_protocol_implementation() for now.
original_subtype = subtype

# Report missing members
missing = get_missing_protocol_members(subtype, supertype, skip=skip)
Expand Down Expand Up @@ -2274,7 +2279,7 @@ def report_protocol_problems(

# Report member type conflicts
conflict_types = get_conflict_protocol_types(
subtype, supertype, class_obj=class_obj, options=self.options
subtype, original_subtype, supertype, class_obj=class_obj, options=self.options
)
if conflict_types and (
not is_subtype(subtype, erase_type(supertype), options=self.options)
Expand Down Expand Up @@ -3191,7 +3196,11 @@ def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str


def get_conflict_protocol_types(
left: Instance, right: Instance, class_obj: bool = False, options: Options | None = None
left: Instance,
original_left: Type,
right: Instance,
class_obj: bool = False,
options: Options | None = None,
) -> list[tuple[str, Type, Type, bool]]:
"""Find members that are defined in 'left' but have incompatible types.
Return them as a list of ('member', 'got', 'expected', 'is_lvalue').
Expand All @@ -3203,7 +3212,7 @@ def get_conflict_protocol_types(
continue
supertype = find_member(member, right, left)
assert supertype is not None
subtype = get_protocol_member(left, member, class_obj)
subtype = get_protocol_member(left, original_left, member, class_obj)
if not subtype:
continue
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True, options=options)
Expand All @@ -3219,7 +3228,9 @@ def get_conflict_protocol_types(
different_setter = True
supertype = set_supertype
if IS_EXPLICIT_SETTER in get_member_flags(member, left):
set_subtype = get_protocol_member(left, member, class_obj, is_lvalue=True)
set_subtype = get_protocol_member(
left, original_left, member, class_obj, is_lvalue=True
)
if set_subtype and not is_same_type(set_subtype, subtype):
different_setter = True
subtype = set_subtype
Expand Down
27 changes: 20 additions & 7 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,12 @@ def visit_tuple_type(self, left: TupleType) -> bool:
mypy.typeops.tuple_fallback(left), right
):
return True
elif right.type.is_protocol and is_protocol_implementation(
left, right, proper_subtype=self.proper_subtype
):
# Special-case protocols to get precise binding of self type for
# custom tuple types like NamedTuples.
return True
return False
elif isinstance(right, TupleType):
# If right has a variadic unpack this needs special handling. If there is a TypeVarTuple
Expand Down Expand Up @@ -1185,7 +1191,7 @@ def pop_on_exit(stack: list[tuple[T, T]], left: T, right: T) -> Iterator[None]:


def is_protocol_implementation(
left: Instance,
left: Instance | TupleType,
right: Instance,
proper_subtype: bool = False,
class_obj: bool = False,
Expand All @@ -1212,6 +1218,11 @@ def f(self) -> A: ...
assert right.type.is_protocol
if skip is None:
skip = []
# Preserve original left type for precise self-type binding. Only tuple types are
# supported for now.
original_left = left
if isinstance(left, TupleType):
left = mypy.typeops.tuple_fallback(left)
# We need to record this check to generate protocol fine-grained dependencies.
type_state.record_protocol_subtype_check(left.type, right.type)
# nominal subtyping currently ignores '__init__' and '__new__' signatures
Expand All @@ -1234,10 +1245,10 @@ def f(self) -> A: ...
ignore_names = member != "__call__" # __call__ can be passed kwargs
# The third argument below indicates to what self type is bound.
# We always bind self to the subtype. (Similarly to nominal types).
supertype = find_member(member, right, left)
supertype = find_member(member, right, original_left)
assert supertype is not None

subtype = get_protocol_member(left, member, class_obj)
subtype = get_protocol_member(left, original_left, member, class_obj)
# Useful for debugging:
# print(member, 'of', left, 'has type', subtype)
# print(member, 'of', right, 'has type', supertype)
Expand All @@ -1264,9 +1275,11 @@ def f(self) -> A: ...
if IS_SETTABLE in superflags:
# Check opposite direction for settable attributes.
if IS_EXPLICIT_SETTER in superflags:
supertype = find_member(member, right, left, is_lvalue=True)
supertype = find_member(member, right, original_left, is_lvalue=True)
if IS_EXPLICIT_SETTER in subflags:
subtype = get_protocol_member(left, member, class_obj, is_lvalue=True)
subtype = get_protocol_member(
left, original_left, member, class_obj, is_lvalue=True
)
# At this point we know attribute is present on subtype, otherwise we
# would return False above.
assert supertype is not None and subtype is not None
Expand Down Expand Up @@ -1305,7 +1318,7 @@ def f(self) -> A: ...


def get_protocol_member(
left: Instance, member: str, class_obj: bool, is_lvalue: bool = False
left: Instance, original_left: Type, member: str, class_obj: bool, is_lvalue: bool = False
) -> Type | None:
if member == "__call__" and class_obj:
# Special case: class objects always have __call__ that is just the constructor.
Expand All @@ -1316,7 +1329,7 @@ def get_protocol_member(
# if constructor signature didn't match, this can cause many false negatives.
return None

subtype = find_member(member, left, left, class_obj=class_obj, is_lvalue=is_lvalue)
subtype = find_member(member, left, original_left, class_obj=class_obj, is_lvalue=is_lvalue)
if isinstance(subtype, PartialType):
subtype = (
NoneType()
Expand Down
30 changes: 30 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -4779,3 +4779,33 @@ class A(Protocol):
pass

[builtins fixtures/tuple.pyi]

[case testTupleTypeSelfTypeProto]
from typing import Protocol, TypeVar

R = TypeVar("R", covariant=True)

class P(Protocol[R]):
def rep(self) -> R: ...

T = TypeVar("T")
def rep(x: P[T]) -> T: ...

class C(tuple[int, str]):
def rep(self: T) -> T: ...

t: C
reveal_type(t) # N: Revealed type is "tuple[builtins.int, builtins.str, fallback=__main__.C]"
reveal_type(rep(t)) # N: Revealed type is "tuple[builtins.int, builtins.str, fallback=__main__.C]"

def ok_rep(x: P[tuple[int, str]]) -> None: ...
ok_rep(t)

def bad_rep(x: P[tuple[str, int]]) -> None: ...
bad_rep(t) # E: Argument 1 to "bad_rep" has incompatible type "C"; expected "P[tuple[str, int]]" \
# N: Following member(s) of "C" have conflicts: \
# N: Expected: \
# N: def rep(self) -> tuple[str, int] \
# N: Got: \
# N: def rep(self) -> C
[builtins fixtures/tuple.pyi]
Loading