From 3ec7e28720255f921e32eb0facb0f3a9c72621c3 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:10:17 -0700 Subject: [PATCH 1/6] [Add] Pyright stubs for quadrants types and NDArray subscript syntax Add .pyi stub files for primitive_types, ndarray_type, utils, annotations, and compound_types so that Pyright/mypy can type-check code using quadrants types. Update .gitignore to only ignore generated stubs in stubs/quadrants/_lib/. Add pyright test for primitive dtype annotations and NDArray subscript syntax. --- .gitignore | 2 +- stubs/quadrants/types/annotations.pyi | 13 ++++ stubs/quadrants/types/compound_types.pyi | 10 +++ stubs/quadrants/types/ndarray_type.pyi | 27 ++++++++ stubs/quadrants/types/primitive_types.pyi | 69 ++++++++++++++++++ stubs/quadrants/types/utils.pyi | 6 ++ tests/python/pyright/test_ndarray_type.py | 21 ++++-- tests/python/pyright/test_primitive_types.py | 73 ++++++++++++++++++++ 8 files changed, 215 insertions(+), 6 deletions(-) create mode 100644 stubs/quadrants/types/annotations.pyi create mode 100644 stubs/quadrants/types/compound_types.pyi create mode 100644 stubs/quadrants/types/ndarray_type.pyi create mode 100644 stubs/quadrants/types/primitive_types.pyi create mode 100644 stubs/quadrants/types/utils.pyi create mode 100644 tests/python/pyright/test_primitive_types.py diff --git a/.gitignore b/.gitignore index 6d05d1ed47..8a88637127 100644 --- a/.gitignore +++ b/.gitignore @@ -93,6 +93,6 @@ imgui.ini !pyrightconfig.json *.whl *.so -stubs/ +stubs/quadrants/_lib/ CHANGELOG.md python/quadrants/_version.py diff --git a/stubs/quadrants/types/annotations.pyi b/stubs/quadrants/types/annotations.pyi new file mode 100644 index 0000000000..848e969ebb --- /dev/null +++ b/stubs/quadrants/types/annotations.pyi @@ -0,0 +1,13 @@ +from typing import Any, Generic, TypeVar + +T = TypeVar("T") + +class Template(Generic[T]): + element_type: type[T] + ndim: int | None + def __init__(self, element_type: type[T] = ..., ndim: int | None = ...) -> None: ... + def __getitem__(self, i: Any) -> T: ... + +template = Template + +class sparse_matrix_builder: ... diff --git a/stubs/quadrants/types/compound_types.pyi b/stubs/quadrants/types/compound_types.pyi new file mode 100644 index 0000000000..f834646c42 --- /dev/null +++ b/stubs/quadrants/types/compound_types.pyi @@ -0,0 +1,10 @@ +from typing import Any + +class CompoundType: + def from_kernel_struct_ret(self, launch_ctx: Any, index: tuple[Any, ...]) -> Any: ... + def check_matched(self, other: Any) -> bool: ... + def to_string(self) -> str: ... + +def matrix(n: int | None = ..., m: int | None = ..., dtype: Any = ...) -> Any: ... +def vector(n: int | None = ..., dtype: Any = ...) -> Any: ... +def struct(**kwargs: Any) -> Any: ... diff --git a/stubs/quadrants/types/ndarray_type.pyi b/stubs/quadrants/types/ndarray_type.pyi new file mode 100644 index 0000000000..22d1b269a7 --- /dev/null +++ b/stubs/quadrants/types/ndarray_type.pyi @@ -0,0 +1,27 @@ +from typing import Any + +class NdarrayType: + dtype: Any + ndim: int | None + needs_grad: bool | None + boundary: int + + def __init__( + self, + dtype: Any = ..., + ndim: int | None = ..., + element_dim: int | None = ..., + element_shape: tuple[int, ...] | None = ..., + field_dim: int | None = ..., + needs_grad: bool | None = ..., + boundary: str = ..., + ) -> None: ... + @classmethod + def __class_getitem__(cls, args: Any) -> type[NdarrayType]: ... + def __getitem__(self, i: Any) -> Any: ... + def __setitem__(self, i: Any, v: Any) -> None: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + +ndarray = NdarrayType +NDArray = NdarrayType diff --git a/stubs/quadrants/types/primitive_types.pyi b/stubs/quadrants/types/primitive_types.pyi new file mode 100644 index 0000000000..9191ec125e --- /dev/null +++ b/stubs/quadrants/types/primitive_types.pyi @@ -0,0 +1,69 @@ +from typing import Any, ClassVar, Union + +from quadrants._lib.core.quadrants_python import DataTypeCxx + +class PrimitiveMeta(type): + cxx: DataTypeCxx + def __eq__(cls, other: object) -> bool: ... + def __ne__(cls, other: object) -> bool: ... + def __hash__(cls) -> int: ... + def __repr__(cls) -> str: ... + def __getattr__(cls, name: str) -> Any: ... + +class PrimitiveBase(metaclass=PrimitiveMeta): + cxx: ClassVar[DataTypeCxx] + +class f16(PrimitiveBase): ... +class f32(PrimitiveBase): ... +class f64(PrimitiveBase): ... +class i8(PrimitiveBase): ... +class i16(PrimitiveBase): ... +class i32(PrimitiveBase): ... +class i64(PrimitiveBase): ... +class u1(PrimitiveBase): ... +class u8(PrimitiveBase): ... +class u16(PrimitiveBase): ... +class u32(PrimitiveBase): ... +class u64(PrimitiveBase): ... + +float16 = f16 +float32 = f32 +float64 = f64 +int8 = i8 +int16 = i16 +int32 = i32 +int64 = i64 +uint1 = u1 +uint8 = u8 +uint16 = u16 +uint32 = u32 +uint64 = u64 + +# Raw C++ DataType instances (internal use) +f16_cxx: DataTypeCxx +f32_cxx: DataTypeCxx +f64_cxx: DataTypeCxx +i8_cxx: DataTypeCxx +i16_cxx: DataTypeCxx +i32_cxx: DataTypeCxx +i64_cxx: DataTypeCxx +u1_cxx: DataTypeCxx +u8_cxx: DataTypeCxx +u16_cxx: DataTypeCxx +u32_cxx: DataTypeCxx +u64_cxx: DataTypeCxx + +class RefType: + tp: Any + def __init__(self, tp: Any) -> None: ... + +def ref(tp: Any) -> RefType: ... + +real_types: set[type[PrimitiveBase] | type] +real_type_ids: set[int] +integer_types: set[type[PrimitiveBase] | type] +integer_type_ids: set[int] +all_types: set[type[PrimitiveBase] | type] +cxx_type_ids: set[int] +type_ids: set[int] +_python_primitive_types = Union[int, float, bool, str, None] diff --git a/stubs/quadrants/types/utils.pyi b/stubs/quadrants/types/utils.pyi new file mode 100644 index 0000000000..6758391c78 --- /dev/null +++ b/stubs/quadrants/types/utils.pyi @@ -0,0 +1,6 @@ +from typing import Any + +def is_signed(dt: Any) -> bool: ... +def is_integral(dt: Any) -> bool: ... +def is_real(dt: Any) -> bool: ... +def is_tensor(dt: Any) -> bool: ... diff --git a/tests/python/pyright/test_ndarray_type.py b/tests/python/pyright/test_ndarray_type.py index f728a99f69..4bba59de58 100644 --- a/tests/python/pyright/test_ndarray_type.py +++ b/tests/python/pyright/test_ndarray_type.py @@ -1,4 +1,4 @@ -# This is a test file. It just has to exist, to check that pyright works with it. +# Pyright test: NDArray annotations accepted by the type checker and functional at runtime. import quadrants as qd @@ -7,21 +7,30 @@ qd.init(arch=qd.cpu) +# Legacy call syntax (still works, pyright warns about call expressions in type positions) @qd.kernel -def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... +def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] @qd.kernel() -def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... +def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] + + +# New subscript syntax (preferred, no pyright warnings) +@qd.kernel +def k3(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ... @qd.data_oriented class SomeClass: @qd.kernel - def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... + def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] @qd.kernel() - def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... + def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] + + @qd.kernel + def k3(self, a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ... @test_utils.test() @@ -29,7 +38,9 @@ def test_ndarray_type(): a = qd.ndarray(qd.i32, (10,)) k1(a, a, a) k2(a, a, a) + k3(a, a, a) some_class = SomeClass() some_class.k1(a, a, a) some_class.k2(a, a, a) + some_class.k3(a, a, a) diff --git a/tests/python/pyright/test_primitive_types.py b/tests/python/pyright/test_primitive_types.py new file mode 100644 index 0000000000..4c618ab07e --- /dev/null +++ b/tests/python/pyright/test_primitive_types.py @@ -0,0 +1,73 @@ +"""Pyright test: primitive dtype classes and NDArray subscript syntax. + +This file must produce zero pyright errors. It validates that: +- Primitive dtypes (f32, i32, etc.) work as type annotations +- NDArray[dtype, ndim] subscript syntax is accepted +- from __future__ import annotations (stringified) works +- NDArray works via qd.types and top-level qd.NDArray +- Return types and Optional wrappers work +""" + +from __future__ import annotations + +from typing import Optional + +import quadrants as qd + + +# Primitive types as annotations +def accept_f32(x: qd.f32) -> None: ... + + +def accept_i32(x: qd.i32) -> None: ... + + +def accept_any_dtype(x: qd.f32 | qd.i32 | qd.u8) -> None: ... + + +# NDArray subscript: dtype + ndim +def kernel_2d(a: qd.types.NDArray[qd.f32, 2]) -> None: ... + + +# NDArray subscript: dtype only +def kernel_dtype(a: qd.types.NDArray[qd.i32]) -> None: ... + + +# NDArray bare (no subscript) +def kernel_bare(a: qd.types.NDArray) -> None: ... + + +# Multiple NDArray args with different types +def multi_args( + a: qd.types.NDArray[qd.f32, 2], + b: qd.types.NDArray[qd.i32, 1], + c: qd.types.NDArray, +) -> None: ... + + +# Top-level NDArray alias (accessible via qd.types) +def top_level(a: qd.types.NDArray[qd.f32, 2]) -> None: ... + + +# Return types +def make_arr() -> qd.types.NDArray[qd.f32, 2]: ... + + +# Optional wrapping +def maybe_arr(x: Optional[qd.types.NDArray[qd.f32, 2]]) -> None: ... + + +# Variable annotations +field1: qd.types.NDArray[qd.f32, 2] +field2: qd.types.NDArray + + +# In class body +class MyModel: + buf: qd.types.NDArray[qd.f32, 3] + + def forward(self, x: qd.types.NDArray[qd.f32, 2]) -> qd.types.NDArray[qd.f32, 2]: ... + + +# Access via qd.types submodule +def via_types(x: qd.types.NDArray[qd.types.f32, 2]) -> None: ... From 659b161bd88af96cca7662d19d47a3f596021436 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 15:55:51 -0700 Subject: [PATCH 2/6] [Lint] Move imports to top level to satisfy pylint C0415 Moves `inspect` import in `lang/exception.py` and `get_func_signature` import in `lang/_kernel_impl_dataclass.py` from function bodies up to module level, fixing the `import-outside-toplevel` lint violations. Made-with: Cursor --- python/quadrants/lang/_kernel_impl_dataclass.py | 3 +-- python/quadrants/lang/exception.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index ebd298dff5..baec2c68ab 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -7,6 +7,7 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) +from quadrants.lang.exception import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata @@ -72,8 +73,6 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - from quadrants.lang.exception import get_func_signature - sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index beaf2eeb0e..8db8c5bffc 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,5 +1,7 @@ # type: ignore +import inspect + from quadrants._lib import core @@ -59,8 +61,6 @@ def get_ret(needed, provided): def get_func_signature(func): """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" - import inspect - try: return inspect.signature(func, eval_str=True) except (NameError, AttributeError) as e: From 87e848c9a05855c23988b59eaa6d0f24392bd505 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 16:40:43 -0700 Subject: [PATCH 3/6] [Test] Require data64 for f64 dual field debug-mode regression test Mac vulkan/metal backends don't support f64, causing worker crashes. Made-with: Cursor --- tests/python/test_ad_basics_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index 760db9a7c4..9dcbd1f060 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -126,7 +126,7 @@ def clear_dual_test(): assert y.dual[None] == 4.0 -@test_utils.test(debug=True) +@test_utils.test(require=qd.extension.data64, debug=True) def test_dual_field_dtype_preserved_in_debug_mode(): """Regression: debug-mode checkbit must not shadow the outer dtype.""" x = qd.field(qd.f64, shape=(), needs_dual=True) From 2469e06ad22c30b74499fe2caae6ac6447395f53 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 21:17:47 -0700 Subject: [PATCH 4/6] [Fix] CI: pyright PrimitiveMeta.__call__ stub and DataTypeCxx in test_shared_array - Add __call__ to PrimitiveMeta in stubs so pyright accepts qd.u32(j) / qd.i32(...) casts in kernel code (fixes 8 reportCallIssue errors in lang/simt/_tile16.py). - In test_shared_array.py, pass DataTypeCxx (via cook_dtype) to data_type_size since primitive types are now Python classes. Made-with: Cursor --- stubs/quadrants/types/primitive_types.pyi | 1 + tests/python/test_shared_array.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/stubs/quadrants/types/primitive_types.pyi b/stubs/quadrants/types/primitive_types.pyi index 9191ec125e..2029547b19 100644 --- a/stubs/quadrants/types/primitive_types.pyi +++ b/stubs/quadrants/types/primitive_types.pyi @@ -4,6 +4,7 @@ from quadrants._lib.core.quadrants_python import DataTypeCxx class PrimitiveMeta(type): cxx: DataTypeCxx + def __call__(cls, *args: Any, **kwargs: Any) -> Any: ... def __eq__(cls, other: object) -> bool: ... def __ne__(cls, other: object) -> bool: ... def __hash__(cls) -> int: ... diff --git a/tests/python/test_shared_array.py b/tests/python/test_shared_array.py index 6999075d6e..b7ec7cc60b 100644 --- a/tests/python/test_shared_array.py +++ b/tests/python/test_shared_array.py @@ -57,8 +57,8 @@ def test_shared_array_not_accumulated_across_offloads(num_dim, first_shape_delta max_shared_bytes = qd.lang.impl.get_max_shared_memory_bytes(is_lowerbound_ok=True) # 75% of max shared memory in bytes, converted to element counts shared_array_bytes = int(0.75 * max_shared_bytes) - num_elems_1 = shared_array_bytes // qd._lib.core.data_type_size(dtype1) + first_shape_delta_size - num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(dtype2) + num_elems_1 = shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype1)) + first_shape_delta_size + num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype2)) # Build 1D or 2D shape tuples with the same total number of elements. # For 2D, split into (block_dim, num_elems // block_dim). From c3558048c0f242595f6b9c73a62511e3065418d1 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 22:14:06 -0700 Subject: [PATCH 5/6] Apply black formatting to test_shared_array.py Made-with: Cursor --- tests/python/test_shared_array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/test_shared_array.py b/tests/python/test_shared_array.py index b7ec7cc60b..e6ade1739b 100644 --- a/tests/python/test_shared_array.py +++ b/tests/python/test_shared_array.py @@ -57,7 +57,9 @@ def test_shared_array_not_accumulated_across_offloads(num_dim, first_shape_delta max_shared_bytes = qd.lang.impl.get_max_shared_memory_bytes(is_lowerbound_ok=True) # 75% of max shared memory in bytes, converted to element counts shared_array_bytes = int(0.75 * max_shared_bytes) - num_elems_1 = shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype1)) + first_shape_delta_size + num_elems_1 = ( + shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype1)) + first_shape_delta_size + ) num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype2)) # Build 1D or 2D shape tuples with the same total number of elements. From 7bc61313b8b2d4050ff9a73d247851fe945caef2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 03:54:41 -0700 Subject: [PATCH 6/6] [Lint] Remove duplicate get_func_signature import from exception module Made-with: Cursor --- python/quadrants/lang/_kernel_impl_dataclass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index 18a7ae78a1..fe862cfa3e 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -8,7 +8,6 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) -from quadrants.lang.exception import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata