diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 1643d536..e6c20d54 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -19,7 +19,8 @@ import functools import inspect import traceback -from typing import Any, Callable, List, Optional, Sequence, Set, Union, cast +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union, cast import unittest from unittest import mock @@ -539,7 +540,7 @@ def assert_equal_shape_suffix(inputs: Sequence[Array], suffix_len: int) -> None: def _unelided_shape_matches( actual_shape: Sequence[int], - expected_shape: Sequence[Optional[Union[int, Set[int]]]]) -> bool: + expected_shape: Sequence[Optional[Union[int, set[int]]]]) -> bool: """Returns True if `actual_shape` is compatible with `expected_shape`.""" if len(actual_shape) != len(expected_shape): return False @@ -558,8 +559,8 @@ def _shape_matches(actual_shape: Sequence[int], expected_shape: _ai.TShapeMatcher) -> bool: """Returns True if `actual_shape` is compatible with `expected_shape`.""" # Splits `expected_shape` based on the position of the ellipsis, if present. - expected_prefix: List[_ai.TDimMatcher] = [] - expected_suffix: Optional[List[_ai.TDimMatcher]] = None + expected_prefix: list[_ai.TDimMatcher] = [] + expected_suffix: Optional[list[_ai.TDimMatcher]] = None for dim in expected_shape: if dim is Ellipsis: if expected_suffix is not None: @@ -701,8 +702,8 @@ def assert_equal_rank(inputs: Sequence[Array]) -> None: @_static_assertion def assert_rank( inputs: Union[Scalar, Union[Array, Sequence[Array]]], - expected_ranks: Union[int, Set[int], Sequence[Union[int, - Set[int]]]]) -> None: + expected_ranks: Union[int, set[int], Sequence[Union[int, + set[int]]]]) -> None: """Checks that the rank of all inputs matches specified ``expected_ranks``. Valid usages include: diff --git a/chex/_src/asserts_chexify.py b/chex/_src/asserts_chexify.py index 6df66021..e01811f4 100644 --- a/chex/_src/asserts_chexify.py +++ b/chex/_src/asserts_chexify.py @@ -20,7 +20,8 @@ import dataclasses import functools import re -from typing import Any, Callable, FrozenSet +from collections.abc import Callable +from typing import Any from absl import logging from chex._src import asserts_internal as _ai @@ -32,13 +33,13 @@ class _ChexifyChecks: """A set of checks imported from checkify.""" - user: FrozenSet[checkify.ErrorCategory] = checkify.user_checks - nan: FrozenSet[checkify.ErrorCategory] = checkify.nan_checks - index: FrozenSet[checkify.ErrorCategory] = checkify.index_checks - div: FrozenSet[checkify.ErrorCategory] = checkify.div_checks - float: FrozenSet[checkify.ErrorCategory] = checkify.float_checks - automatic: FrozenSet[checkify.ErrorCategory] = checkify.automatic_checks - all: FrozenSet[checkify.ErrorCategory] = checkify.all_checks + user: frozenset[checkify.ErrorCategory] = checkify.user_checks + nan: frozenset[checkify.ErrorCategory] = checkify.nan_checks + index: frozenset[checkify.ErrorCategory] = checkify.index_checks + div: frozenset[checkify.ErrorCategory] = checkify.div_checks + float: frozenset[checkify.ErrorCategory] = checkify.float_checks + automatic: frozenset[checkify.ErrorCategory] = checkify.automatic_checks + all: frozenset[checkify.ErrorCategory] = checkify.all_checks _chexify_error_pattern = re.compile( @@ -89,7 +90,7 @@ def _check_if_hanging_assertions(): def chexify( fn: Callable[..., Any], async_check: bool = True, - errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user, + errors: frozenset[checkify.ErrorCategory] = ChexifyChecks.user, ) -> Callable[..., Any]: """Wraps a transformed function `fn` to enable Chex value assertions. diff --git a/chex/_src/asserts_chexify_test.py b/chex/_src/asserts_chexify_test.py index aa8baae3..d5fa54f8 100644 --- a/chex/_src/asserts_chexify_test.py +++ b/chex/_src/asserts_chexify_test.py @@ -19,7 +19,8 @@ import sys import threading import time -from typing import Any, Optional, Sequence, Type +from collections.abc import Sequence +from typing import Any, Optional from absl.testing import absltest from absl.testing import parameterized @@ -60,7 +61,7 @@ def _assert_noop(*args, custom_message: Optional[str] = None, custom_message_format_vars: Sequence[Any] = (), include_default_message: bool = True, - exception_type: Type[Exception] = AssertionError, + exception_type: type[Exception] = AssertionError, **kwargs) -> None: """No-op.""" del args, custom_message, custom_message_format_vars diff --git a/chex/_src/asserts_internal.py b/chex/_src/asserts_internal.py index 2f11dd3f..835e434c 100644 --- a/chex/_src/asserts_internal.py +++ b/chex/_src/asserts_internal.py @@ -29,7 +29,8 @@ import re import threading import traceback -from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union from absl import logging from chex._src import pytypes @@ -56,7 +57,7 @@ TJittableAssertFn = Callable[..., pytypes.Array] # a predicate function # Matchers. -TDimMatcher = Optional[Union[int, Set[int], type(Ellipsis)]] +TDimMatcher = Optional[Union[int, set[int], type(Ellipsis)]] TShapeMatcher = Sequence[TDimMatcher] @@ -101,7 +102,7 @@ def inner_fn(*args, **kwargs): return inner_fn -def get_stacktrace_without_chex_internals() -> List[traceback.FrameSummary]: +def get_stacktrace_without_chex_internals() -> list[traceback.FrameSummary]: """Returns the latest non-chex frame from the call stack.""" stacktrace = list(traceback.extract_stack()) for i in reversed(range(len(stacktrace))): @@ -160,7 +161,7 @@ def _assert_on_host(*args, custom_message: Optional[str] = None, custom_message_format_vars: Sequence[Any] = (), include_default_message: bool = True, - exception_type: Type[Exception] = AssertionError, + exception_type: type[Exception] = AssertionError, **kwargs) -> None: # Format error's stack trace to remove Chex' internal frames. assertion_exc = None @@ -233,7 +234,7 @@ def _chex_assert_fn(*args, custom_message: Optional[str] = None, custom_message_format_vars: Sequence[Any] = (), include_default_message: bool = True, - exception_type: Type[Exception] = AssertionError, + exception_type: type[Exception] = AssertionError, **kwargs) -> None: if DISABLE_ASSERTIONS: return @@ -314,7 +315,7 @@ def num_devices_available(devtype: str, backend: Optional[str] = None) -> int: return sum(d.platform == devtype for d in jax.devices(backend)) -def get_tracers(tree: pytypes.ArrayTree) -> Tuple[jax.core.Tracer]: +def get_tracers(tree: pytypes.ArrayTree) -> tuple[jax.core.Tracer]: """Returns a tuple with tracers from a tree.""" return tuple( x for x in jax.tree_util.tree_leaves(tree) @@ -408,7 +409,7 @@ def assert_trees_all_eq_comparator_jittable( "forgot the `error_msg_fn` arg to `assert_trees_xxx`?") def _tree_error_msg_fn( - path: Tuple[Union[int, str, Hashable]], i_1: int, i_2: int): + path: tuple[Union[int, str, Hashable]], i_1: int, i_2: int): if path: return ( f"Trees {i_1} and {i_2} differ in leaves '{path}':" @@ -458,7 +459,7 @@ def _cmp_leaves(path, *leaves): def convert_jax_path_to_dm_path( jax_tree_path: Sequence[JaxKeyType], -) -> Tuple[Union[int, str, Hashable]]: +) -> tuple[Union[int, str, Hashable]]: """Converts a path from jax.tree_util to one from dm-tree.""" # pytype:disable=attribute-error diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index f981c8a2..29105404 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -20,7 +20,8 @@ import dataclasses import pickle import sys -from typing import Any, Generic, Mapping, TypeVar +from collections.abc import Mapping +from typing import Any, Generic, TypeVar import unittest from absl.testing import absltest diff --git a/chex/_src/dimensions.py b/chex/_src/dimensions.py index 695c5beb..b4ec51ab 100644 --- a/chex/_src/dimensions.py +++ b/chex/_src/dimensions.py @@ -17,10 +17,11 @@ from collections.abc import Sized import math import re -from typing import Any, Collection, Dict, Optional, Tuple +from collections.abc import Collection +from typing import Any, Optional -Shape = Tuple[Optional[int], ...] +Shape = tuple[Optional[int], ...] class Dimensions: @@ -198,7 +199,7 @@ def __repr__(self) -> str: args = ', '.join(f'{k}={v}' for k, v in sorted(self._asdict().items())) return f'{type(self).__name__}({args})' - def _asdict(self) -> Dict[str, Optional[int]]: + def _asdict(self) -> dict[str, Optional[int]]: return {k: v for k, v in self.__dict__.items() if re.fullmatch(r'[a-zA-Z]', k)} diff --git a/chex/_src/fake.py b/chex/_src/fake.py index 179e6f7f..07fa86cb 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -25,7 +25,8 @@ import inspect import os import re -from typing import Any, Callable, Iterable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union from unittest import mock from absl import flags import jax diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index c6da0777..ef9ad046 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,7 +14,8 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Iterable, Mapping, Sequence, Union +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Union import jax import numpy as np diff --git a/chex/_src/restrict_backends.py b/chex/_src/restrict_backends.py index 19937d74..cb747e2d 100644 --- a/chex/_src/restrict_backends.py +++ b/chex/_src/restrict_backends.py @@ -29,7 +29,8 @@ """ import contextlib import functools -from typing import Callable, Optional, Sequence +from collections.abc import Callable, Sequence +from typing import Optional from jax._src import compiler diff --git a/chex/_src/variants.py b/chex/_src/variants.py index 6a5ddec6..142fb989 100644 --- a/chex/_src/variants.py +++ b/chex/_src/variants.py @@ -18,7 +18,8 @@ import functools import inspect import itertools -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import unittest from absl import flags diff --git a/chex/_src/warnings.py b/chex/_src/warnings.py index 28704d7c..649b88ca 100644 --- a/chex/_src/warnings.py +++ b/chex/_src/warnings.py @@ -15,7 +15,8 @@ """Utilities to emit warnings.""" import functools -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import warnings diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index 33c1dd23..ae3d29a6 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -16,7 +16,8 @@ import inspect import types -from typing import Any, Mapping, Sequence, Tuple +from collections.abc import Mapping, Sequence +from typing import Any import chex as _module from sphinx import application @@ -25,7 +26,7 @@ def find_internal_python_modules( - root_module: types.ModuleType,) -> Sequence[Tuple[str, types.ModuleType]]: + root_module: types.ModuleType,) -> Sequence[tuple[str, types.ModuleType]]: """Returns `(name, module)` for all submodules under `root_module`.""" modules = set([(root_module.__name__, root_module)]) visited = set()