diff --git a/.gitignore b/.gitignore index 0aaf63e31c..2b230d5b72 100644 --- a/.gitignore +++ b/.gitignore @@ -93,7 +93,7 @@ imgui.ini !pyrightconfig.json *.whl *.so -stubs/ +stubs/quadrants/_lib/ CHANGELOG.md python/quadrants/_version.py env.sh diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index 771dd56b02..8db8c5bffc 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,5 +1,7 @@ # type: ignore +import inspect + from quadrants._lib import core @@ -57,6 +59,14 @@ def get_ret(needed, provided): return QuadrantsRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}") +def get_func_signature(func): + """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" + try: + return inspect.signature(func, eval_str=True) + except (NameError, AttributeError) as e: + raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e + + def handle_exception_from_cpp(exc): if isinstance(exc, core.QuadrantsTypeError): return QuadrantsTypeError(str(exc)) diff --git a/stubs/quadrants/types/annotations.pyi b/stubs/quadrants/types/annotations.pyi new file mode 100644 index 0000000000..848e969ebb --- /dev/null +++ b/stubs/quadrants/types/annotations.pyi @@ -0,0 +1,13 @@ +from typing import Any, Generic, TypeVar + +T = TypeVar("T") + +class Template(Generic[T]): + element_type: type[T] + ndim: int | None + def __init__(self, element_type: type[T] = ..., ndim: int | None = ...) -> None: ... + def __getitem__(self, i: Any) -> T: ... + +template = Template + +class sparse_matrix_builder: ... diff --git a/stubs/quadrants/types/compound_types.pyi b/stubs/quadrants/types/compound_types.pyi new file mode 100644 index 0000000000..f834646c42 --- /dev/null +++ b/stubs/quadrants/types/compound_types.pyi @@ -0,0 +1,10 @@ +from typing import Any + +class CompoundType: + def from_kernel_struct_ret(self, launch_ctx: Any, index: tuple[Any, ...]) -> Any: ... + def check_matched(self, other: Any) -> bool: ... + def to_string(self) -> str: ... + +def matrix(n: int | None = ..., m: int | None = ..., dtype: Any = ...) -> Any: ... +def vector(n: int | None = ..., dtype: Any = ...) -> Any: ... +def struct(**kwargs: Any) -> Any: ... diff --git a/stubs/quadrants/types/ndarray_type.pyi b/stubs/quadrants/types/ndarray_type.pyi new file mode 100644 index 0000000000..22d1b269a7 --- /dev/null +++ b/stubs/quadrants/types/ndarray_type.pyi @@ -0,0 +1,27 @@ +from typing import Any + +class NdarrayType: + dtype: Any + ndim: int | None + needs_grad: bool | None + boundary: int + + def __init__( + self, + dtype: Any = ..., + ndim: int | None = ..., + element_dim: int | None = ..., + element_shape: tuple[int, ...] | None = ..., + field_dim: int | None = ..., + needs_grad: bool | None = ..., + boundary: str = ..., + ) -> None: ... + @classmethod + def __class_getitem__(cls, args: Any) -> type[NdarrayType]: ... + def __getitem__(self, i: Any) -> Any: ... + def __setitem__(self, i: Any, v: Any) -> None: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + +ndarray = NdarrayType +NDArray = NdarrayType diff --git a/stubs/quadrants/types/primitive_types.pyi b/stubs/quadrants/types/primitive_types.pyi new file mode 100644 index 0000000000..2029547b19 --- /dev/null +++ b/stubs/quadrants/types/primitive_types.pyi @@ -0,0 +1,70 @@ +from typing import Any, ClassVar, Union + +from quadrants._lib.core.quadrants_python import DataTypeCxx + +class PrimitiveMeta(type): + cxx: DataTypeCxx + def __call__(cls, *args: Any, **kwargs: Any) -> Any: ... + def __eq__(cls, other: object) -> bool: ... + def __ne__(cls, other: object) -> bool: ... + def __hash__(cls) -> int: ... + def __repr__(cls) -> str: ... + def __getattr__(cls, name: str) -> Any: ... + +class PrimitiveBase(metaclass=PrimitiveMeta): + cxx: ClassVar[DataTypeCxx] + +class f16(PrimitiveBase): ... +class f32(PrimitiveBase): ... +class f64(PrimitiveBase): ... +class i8(PrimitiveBase): ... +class i16(PrimitiveBase): ... +class i32(PrimitiveBase): ... +class i64(PrimitiveBase): ... +class u1(PrimitiveBase): ... +class u8(PrimitiveBase): ... +class u16(PrimitiveBase): ... +class u32(PrimitiveBase): ... +class u64(PrimitiveBase): ... + +float16 = f16 +float32 = f32 +float64 = f64 +int8 = i8 +int16 = i16 +int32 = i32 +int64 = i64 +uint1 = u1 +uint8 = u8 +uint16 = u16 +uint32 = u32 +uint64 = u64 + +# Raw C++ DataType instances (internal use) +f16_cxx: DataTypeCxx +f32_cxx: DataTypeCxx +f64_cxx: DataTypeCxx +i8_cxx: DataTypeCxx +i16_cxx: DataTypeCxx +i32_cxx: DataTypeCxx +i64_cxx: DataTypeCxx +u1_cxx: DataTypeCxx +u8_cxx: DataTypeCxx +u16_cxx: DataTypeCxx +u32_cxx: DataTypeCxx +u64_cxx: DataTypeCxx + +class RefType: + tp: Any + def __init__(self, tp: Any) -> None: ... + +def ref(tp: Any) -> RefType: ... + +real_types: set[type[PrimitiveBase] | type] +real_type_ids: set[int] +integer_types: set[type[PrimitiveBase] | type] +integer_type_ids: set[int] +all_types: set[type[PrimitiveBase] | type] +cxx_type_ids: set[int] +type_ids: set[int] +_python_primitive_types = Union[int, float, bool, str, None] diff --git a/stubs/quadrants/types/utils.pyi b/stubs/quadrants/types/utils.pyi new file mode 100644 index 0000000000..6758391c78 --- /dev/null +++ b/stubs/quadrants/types/utils.pyi @@ -0,0 +1,6 @@ +from typing import Any + +def is_signed(dt: Any) -> bool: ... +def is_integral(dt: Any) -> bool: ... +def is_real(dt: Any) -> bool: ... +def is_tensor(dt: Any) -> bool: ... diff --git a/tests/python/pyright/test_ndarray_type.py b/tests/python/pyright/test_ndarray_type.py index f728a99f69..4bba59de58 100644 --- a/tests/python/pyright/test_ndarray_type.py +++ b/tests/python/pyright/test_ndarray_type.py @@ -1,4 +1,4 @@ -# This is a test file. It just has to exist, to check that pyright works with it. +# Pyright test: NDArray annotations accepted by the type checker and functional at runtime. import quadrants as qd @@ -7,21 +7,30 @@ qd.init(arch=qd.cpu) +# Legacy call syntax (still works, pyright warns about call expressions in type positions) @qd.kernel -def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... +def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] @qd.kernel() -def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... +def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] + + +# New subscript syntax (preferred, no pyright warnings) +@qd.kernel +def k3(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ... @qd.data_oriented class SomeClass: @qd.kernel - def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... + def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] @qd.kernel() - def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... + def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm] + + @qd.kernel + def k3(self, a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ... @test_utils.test() @@ -29,7 +38,9 @@ def test_ndarray_type(): a = qd.ndarray(qd.i32, (10,)) k1(a, a, a) k2(a, a, a) + k3(a, a, a) some_class = SomeClass() some_class.k1(a, a, a) some_class.k2(a, a, a) + some_class.k3(a, a, a) diff --git a/tests/python/pyright/test_primitive_types.py b/tests/python/pyright/test_primitive_types.py new file mode 100644 index 0000000000..4c618ab07e --- /dev/null +++ b/tests/python/pyright/test_primitive_types.py @@ -0,0 +1,73 @@ +"""Pyright test: primitive dtype classes and NDArray subscript syntax. + +This file must produce zero pyright errors. It validates that: +- Primitive dtypes (f32, i32, etc.) work as type annotations +- NDArray[dtype, ndim] subscript syntax is accepted +- from __future__ import annotations (stringified) works +- NDArray works via qd.types and top-level qd.NDArray +- Return types and Optional wrappers work +""" + +from __future__ import annotations + +from typing import Optional + +import quadrants as qd + + +# Primitive types as annotations +def accept_f32(x: qd.f32) -> None: ... + + +def accept_i32(x: qd.i32) -> None: ... + + +def accept_any_dtype(x: qd.f32 | qd.i32 | qd.u8) -> None: ... + + +# NDArray subscript: dtype + ndim +def kernel_2d(a: qd.types.NDArray[qd.f32, 2]) -> None: ... + + +# NDArray subscript: dtype only +def kernel_dtype(a: qd.types.NDArray[qd.i32]) -> None: ... + + +# NDArray bare (no subscript) +def kernel_bare(a: qd.types.NDArray) -> None: ... + + +# Multiple NDArray args with different types +def multi_args( + a: qd.types.NDArray[qd.f32, 2], + b: qd.types.NDArray[qd.i32, 1], + c: qd.types.NDArray, +) -> None: ... + + +# Top-level NDArray alias (accessible via qd.types) +def top_level(a: qd.types.NDArray[qd.f32, 2]) -> None: ... + + +# Return types +def make_arr() -> qd.types.NDArray[qd.f32, 2]: ... + + +# Optional wrapping +def maybe_arr(x: Optional[qd.types.NDArray[qd.f32, 2]]) -> None: ... + + +# Variable annotations +field1: qd.types.NDArray[qd.f32, 2] +field2: qd.types.NDArray + + +# In class body +class MyModel: + buf: qd.types.NDArray[qd.f32, 3] + + def forward(self, x: qd.types.NDArray[qd.f32, 2]) -> qd.types.NDArray[qd.f32, 2]: ... + + +# Access via qd.types submodule +def via_types(x: qd.types.NDArray[qd.types.f32, 2]) -> None: ... diff --git a/tests/python/test_shared_array.py b/tests/python/test_shared_array.py index 6999075d6e..e6ade1739b 100644 --- a/tests/python/test_shared_array.py +++ b/tests/python/test_shared_array.py @@ -57,8 +57,10 @@ def test_shared_array_not_accumulated_across_offloads(num_dim, first_shape_delta max_shared_bytes = qd.lang.impl.get_max_shared_memory_bytes(is_lowerbound_ok=True) # 75% of max shared memory in bytes, converted to element counts shared_array_bytes = int(0.75 * max_shared_bytes) - num_elems_1 = shared_array_bytes // qd._lib.core.data_type_size(dtype1) + first_shape_delta_size - num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(dtype2) + num_elems_1 = ( + shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype1)) + first_shape_delta_size + ) + num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype2)) # Build 1D or 2D shape tuples with the same total number of elements. # For 2D, split into (block_dim, num_elems // block_dim).