diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index b5081f113f91..00f9b6589f65 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -774,6 +774,19 @@ of the above sections. f(memoryview(b"")) # Ok +.. option:: --disallow-str-iteration + + Disallow iterating over ``str`` values. + This also rejects using ``str`` where an ``Iterable[str]``, ``Collection[str]``, ``Reversible[str]``, or ``Sequence[str]`` is expected. + To iterate over characters, call ``iter`` on the string explicitly (for example, ``for ch in iter(s)``). + + .. code-block:: python + + s = "hello" + for ch in s: # error: Iterating over "str" is disallowed + print(ch) + + .. option:: --extra-checks This flag enables additional checks that are technically correct but may be diff --git a/docs/source/config_file.rst b/docs/source/config_file.rst index 77f952471007..41c15536230d 100644 --- a/docs/source/config_file.rst +++ b/docs/source/config_file.rst @@ -852,6 +852,14 @@ section of the command line docs. Disable treating ``bytearray`` and ``memoryview`` as subtypes of ``bytes``. This will be enabled by default in *mypy 2.0*. +.. confval:: disallow_str_iteration + + :type: boolean + :default: False + + Disallow iterating over ``str`` values. + This also rejects using ``str`` where an ``Iterable[str]``, ``Collection[str]``, ``Reversible[str]``, or ``Sequence[str]`` is expected. + .. confval:: strict :type: boolean diff --git a/misc/typeshed_patches/0001-Add-explicit-overload-for-iter-of-str.patch b/misc/typeshed_patches/0001-Add-explicit-overload-for-iter-of-str.patch new file mode 100644 index 000000000000..f5683e4e62eb --- /dev/null +++ b/misc/typeshed_patches/0001-Add-explicit-overload-for-iter-of-str.patch @@ -0,0 +1,13 @@ +diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi +index bd425ff3c..61a28e3a9 100644 +--- a/mypy/typeshed/stdlib/builtins.pyi ++++ b/mypy/typeshed/stdlib/builtins.pyi +@@ -1455,6 +1455,8 @@ def input(prompt: object = "", /) -> str: ... 
+ class _GetItemIterable(Protocol[_T_co]): + def __getitem__(self, i: int, /) -> _T_co: ... + ++@overload ++def iter(object: str, /) -> Iterator[str]: ... + @overload + def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ... + @overload diff --git a/mypy/checker.py b/mypy/checker.py index 008becdd3483..796110f7f8ec 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5378,6 +5378,10 @@ def analyze_iterable_item_type_without_expression( echk = self.expr_checker iterable: Type iterable = get_proper_type(type) + + if self.options.disallow_str_iteration and self.is_str_iteration_type(iterable): + self.msg.str_iteration_disallowed(context) + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0] if ( @@ -5390,6 +5394,18 @@ def analyze_iterable_item_type_without_expression( iterable = echk.check_method_call_by_name("__next__", iterator, [], [], context)[0] return iterator, iterable + def is_str_iteration_type(self, typ: Type) -> bool: + typ = get_proper_type(typ) + if isinstance(typ, LiteralType): + return isinstance(typ.value, str) + if isinstance(typ, Instance): + return typ.type.fullname == "builtins.str" + if isinstance(typ, UnionType): + return any(self.is_str_iteration_type(item) for item in typ.relevant_items()) + if isinstance(typ, TypeVarType): + return self.is_str_iteration_type(typ.upper_bound) + return False + def analyze_range_native_int_type(self, expr: Expression) -> Type | None: """Try to infer native int item type from arguments to range(...). 
diff --git a/mypy/main.py b/mypy/main.py index 926e72515d95..c94a23dc79a2 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -937,6 +937,14 @@ def add_invertible_flag( group=strictness_group, ) + add_invertible_flag( + "--disallow-str-iteration", + default=False, + strict_flag=False, + help="Disallow iterating over str instances", + group=strictness_group, + ) + add_invertible_flag( "--extra-checks", default=False, diff --git a/mypy/messages.py b/mypy/messages.py index 5863b8719b95..8ffe16c7e4ab 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1136,6 +1136,10 @@ def wrong_number_values_to_unpack( def unpacking_strings_disallowed(self, context: Context) -> None: self.fail("Unpacking a string is disallowed", context, code=codes.STR_UNPACK) + def str_iteration_disallowed(self, context: Context) -> None: + self.fail('Iterating over "str" is disallowed', context) + self.note("This is because --disallow-str-iteration is enabled", context) + def type_not_iterable(self, type: Type, context: Context) -> None: self.fail(f"{format_type(type, self.options)} object is not iterable", context) @@ -3122,6 +3126,15 @@ def get_conflict_protocol_types( Return them as a list of ('member', 'got', 'expected', 'is_lvalue'). """ assert right.type.is_protocol + + if left.type.fullname == "builtins.str" and right.type.fullname in ( + "collections.abc.Collection", + "typing.Collection", + ): + # `str` doesn't conform to the `Collection` protocol, but we don't want to show that as the reason for the error. 
+ assert options.disallow_str_iteration + return [] + conflicts: list[tuple[str, Type, Type, bool]] = [] for member in right.type.protocol_members: if member in ("__init__", "__new__"): diff --git a/mypy/options.py b/mypy/options.py index cb5088af7e79..641a06ff74b6 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -37,6 +37,7 @@ class BuildType: "disallow_any_unimported", "disallow_incomplete_defs", "disallow_subclassing_any", + "disallow_str_iteration", "disallow_untyped_calls", "disallow_untyped_decorators", "disallow_untyped_defs", @@ -238,6 +239,9 @@ def __init__(self) -> None: # Disable treating bytearray and memoryview as subtypes of bytes self.strict_bytes = False + # Disallow iterating over str or using it as an Iterable/Collection/Reversible/Sequence + self.disallow_str_iteration = False + # Deprecated, use extra_checks instead. self.strict_concatenate = False diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 350d57a7e4ad..33982f70f474 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -479,6 +479,25 @@ def visit_instance(self, left: Instance) -> bool: # dynamic base classes correctly, see #5456. 
return not isinstance(self.right, NoneType) right = self.right + if ( + self.options + and self.options.disallow_str_iteration + and left.type.fullname == "builtins.str" + and isinstance(right, Instance) + and right.type.fullname + in ( + "collections.abc.Collection", + "collections.abc.Iterable", + "collections.abc.Reversible", + "collections.abc.Sequence", + "typing.Collection", + "typing.Iterable", + "typing.Reversible", + "typing.Sequence", + "_typeshed.SupportsLenAndGetItem", + ) + ): + return False if isinstance(right, TupleType) and right.partial_fallback.type.is_enum: return self._is_subtype(left, mypy.typeops.tuple_fallback(right)) if isinstance(right, TupleType): diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi index bd425ff3c212..e8b8676627d1 100644 --- a/mypy/typeshed/stdlib/builtins.pyi +++ b/mypy/typeshed/stdlib/builtins.pyi @@ -1455,6 +1455,8 @@ def input(prompt: object = "", /) -> str: ... class _GetItemIterable(Protocol[_T_co]): def __getitem__(self, i: int, /) -> _T_co: ... +@overload +def iter(object: str, /) -> Iterator[str]: ... @overload def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ... 
@overload diff --git a/test-data/unit/check-flags.test b/test-data/unit/check-flags.test index 8d18c699e628..7b1266917b50 100644 --- a/test-data/unit/check-flags.test +++ b/test-data/unit/check-flags.test @@ -2451,6 +2451,54 @@ f(bytearray(b"asdf")) # E: Argument 1 to "f" has incompatible type "bytearray"; f(memoryview(b"asdf")) # E: Argument 1 to "f" has incompatible type "memoryview"; expected "bytes" [builtins fixtures/primitives.pyi] +[case testDisallowStrIteration] +# flags: --disallow-str-iteration +from typing import Collection, Iterable, Reversible, Sequence, TypeVar + +def takes_str(x: str): + for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str" + [ch for ch in x] # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled + +s = "hello" + +def takes_seq_str(x: Sequence[str]) -> None: ... +takes_seq_str(s) # E: Argument 1 to "takes_seq_str" has incompatible type "str"; expected "Sequence[str]" + +def takes_iter_str(x: Iterable[str]) -> None: ... +takes_iter_str(s) # E: Argument 1 to "takes_iter_str" has incompatible type "str"; expected "Iterable[str]" + +def takes_collection_str(x: Collection[str]) -> None: ... +takes_collection_str(s) # E: Argument 1 to "takes_collection_str" has incompatible type "str"; expected "Collection[str]" + +def takes_reversible_str(x: Reversible[str]) -> None: ... 
+takes_reversible_str(s) # E: Argument 1 to "takes_reversible_str" has incompatible type "str"; expected "Reversible[str]" + +seq: Sequence[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Sequence[str]") +iterable: Iterable[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Iterable[str]") +collection: Collection[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Collection[str]") +reversible: Reversible[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Reversible[str]") + +def takes_maybe_seq(x: "str | Sequence[int]") -> None: + for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str | builtins.int" + +T = TypeVar('T', bound=str) + +def takes_str_upper_bound(x: T) -> None: + for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str" + +reveal_type(reversed(s)) # N: Revealed type is "builtins.reversed[builtins.str]" # E: Argument 1 to "reversed" has incompatible type "str"; expected "Reversible[str]" + +[builtins fixtures/str-iter.pyi] +[typing fixtures/typing-str-iter.pyi] + +[case testIterStrOverload] +# flags: --disallow-str-iteration +reveal_type(iter("foo")) # N: Revealed type is "typing.Iterable[builtins.str]" +[builtins fixtures/dict.pyi] + [case testNoCrashFollowImportsForStubs] # flags: --config-file tmp/mypy.ini {**{"x": "y"}} diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index ed2287511161..8fde41357a3d 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -61,4 +61,7 @@ class ellipsis: pass class BaseException: pass def isinstance(x: object, t: Union[type, Tuple[type, ...]]) -> bool: pass +@overload 
+def iter(__iterable: str) -> Iterable[str]: pass +@overload def iter(__iterable: Iterable[T]) -> Iterator[T]: pass diff --git a/test-data/unit/fixtures/str-iter.pyi b/test-data/unit/fixtures/str-iter.pyi new file mode 100644 index 000000000000..ac0822e708e1 --- /dev/null +++ b/test-data/unit/fixtures/str-iter.pyi @@ -0,0 +1,52 @@ +# Builtins stub used in disallow-str-iteration tests. + + +from _typeshed import SupportsLenAndGetItem +from typing import Generic, Iterator, Sequence, Reversible, TypeVar, overload + +_T = TypeVar("_T") + +class object: + def __init__(self) -> None: pass + +class type: pass +class int: pass +class bool(int): pass +class ellipsis: pass +class slice: pass + +class str: + def __iter__(self) -> Iterator[str]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + def __reversed__(self) -> Iterator[str]: pass + def __getitem__(self, i: int) -> str: pass + +class list(Sequence[_T], Generic[_T]): + def __iter__(self) -> Iterator[_T]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + def __reversed__(self) -> Iterator[_T]: pass + @overload + def __getitem__(self, i: int, /) -> _T: ... + @overload + def __getitem__(self, s: slice, /) -> list[_T]: ... + +class tuple(Sequence[_T], Generic[_T]): + def __iter__(self) -> Iterator[_T]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + def __reversed__(self) -> Iterator[_T]: pass + @overload + def __getitem__(self, i: int, /) -> _T: ... + @overload + def __getitem__(self, s: slice, /) -> list[_T]: ... + +class dict: pass + +class reversed(Iterator[_T]): + @overload + def __new__(cls, sequence: Reversible[_T], /) -> Iterator[_T]: ... # type: ignore[misc] + @overload + def __new__(cls, sequence: SupportsLenAndGetItem[_T], /) -> Iterator[_T]: ... # type: ignore[misc] + def __next__(self) -> _T: ... 
diff --git a/test-data/unit/fixtures/typing-str-iter.pyi b/test-data/unit/fixtures/typing-str-iter.pyi new file mode 100644 index 000000000000..c01b2df66d0c --- /dev/null +++ b/test-data/unit/fixtures/typing-str-iter.pyi @@ -0,0 +1,62 @@ +# Minimal typing fixture for disallow-str-iteration tests. + +from abc import ABCMeta, abstractmethod + +Any = object() +TypeVar = 0 +Generic = 0 +Protocol = 0 +overload = 0 + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_T_co = TypeVar("_T_co", covariant=True) +_VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers. +_TC = TypeVar("_TC", bound=type[object]) + +@runtime_checkable +class Iterable(Protocol[_T_co]): + @abstractmethod + def __iter__(self) -> Iterator[_T_co]: ... + +@runtime_checkable +class Iterator(Iterable[_T_co], Protocol[_T_co]): + @abstractmethod + def __next__(self) -> _T_co: ... + def __iter__(self) -> Iterator[_T_co]: ... + +@runtime_checkable +class Reversible(Iterable[_T_co], Protocol[_T_co]): + @abstractmethod + def __reversed__(self) -> Iterator[_T_co]: ... + +@runtime_checkable +class Container(Protocol[_T_co]): + # This is generic more on vibes than anything else + @abstractmethod + def __contains__(self, x: object, /) -> bool: ... + +@runtime_checkable +class Collection(Iterable[_T_co], Container[_T_co], Protocol[_T_co]): + # Implement Sized (but don't have it as a base class). + @abstractmethod + def __len__(self) -> int: ... + +class Sequence(Reversible[_T_co], Collection[_T_co]): + @overload + @abstractmethod + def __getitem__(self, index: int) -> _T_co: ... + @overload + @abstractmethod + def __getitem__(self, index: slice) -> Sequence[_T_co]: ... + def __contains__(self, value: object) -> bool: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __reversed__(self) -> Iterator[_T_co]: ... + +class Mapping(Collection[_KT], Generic[_KT, _VT_co]): + @abstractmethod + def __getitem__(self, key: _KT, /) -> _VT_co: ... + def __contains__(self, key: object, /) -> bool: ... 
+ +def runtime_checkable(cls: _TC) -> _TC: + return cls diff --git a/test-data/unit/lib-stub/_typeshed.pyi b/test-data/unit/lib-stub/_typeshed.pyi index 054ad0ec0c46..87736007a36a 100644 --- a/test-data/unit/lib-stub/_typeshed.pyi +++ b/test-data/unit/lib-stub/_typeshed.pyi @@ -2,7 +2,12 @@ from typing import Protocol, TypeVar, Iterable _KT = TypeVar("_KT") _VT_co = TypeVar("_VT_co", covariant=True) +_T_co = TypeVar("_T_co", covariant=True) class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: pass def __getitem__(self, __key: _KT) -> _VT_co: pass + +class SupportsLenAndGetItem(Protocol[_T_co]): + def __len__(self) -> int: pass + def __getitem__(self, k: int, /) -> _T_co: pass