Skip to content
Draft
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ imgui.ini
!pyrightconfig.json
*.whl
*.so
stubs/
stubs/quadrants/_lib/
CHANGELOG.md
python/quadrants/_version.py
env.sh
10 changes: 10 additions & 0 deletions python/quadrants/lang/exception.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# type: ignore

import inspect

from quadrants._lib import core


Expand Down Expand Up @@ -57,6 +59,14 @@ def get_ret(needed, provided):
return QuadrantsRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}")


def get_func_signature(func):
"""Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError."""
try:
return inspect.signature(func, eval_str=True)
except (NameError, AttributeError) as e:
raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e


def handle_exception_from_cpp(exc):
if isinstance(exc, core.QuadrantsTypeError):
return QuadrantsTypeError(str(exc))
Expand Down
13 changes: 13 additions & 0 deletions stubs/quadrants/types/annotations.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any, Generic, TypeVar

T = TypeVar("T")

class Template(Generic[T]):
element_type: type[T]
ndim: int | None
def __init__(self, element_type: type[T] = ..., ndim: int | None = ...) -> None: ...
def __getitem__(self, i: Any) -> T: ...

template = Template

class sparse_matrix_builder: ...
10 changes: 10 additions & 0 deletions stubs/quadrants/types/compound_types.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Any

class CompoundType:
def from_kernel_struct_ret(self, launch_ctx: Any, index: tuple[Any, ...]) -> Any: ...
def check_matched(self, other: Any) -> bool: ...
def to_string(self) -> str: ...

def matrix(n: int | None = ..., m: int | None = ..., dtype: Any = ...) -> Any: ...
def vector(n: int | None = ..., dtype: Any = ...) -> Any: ...
def struct(**kwargs: Any) -> Any: ...
27 changes: 27 additions & 0 deletions stubs/quadrants/types/ndarray_type.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any

class NdarrayType:
dtype: Any
ndim: int | None
needs_grad: bool | None
boundary: int

def __init__(
self,
dtype: Any = ...,
ndim: int | None = ...,
element_dim: int | None = ...,
element_shape: tuple[int, ...] | None = ...,
field_dim: int | None = ...,
needs_grad: bool | None = ...,
boundary: str = ...,
) -> None: ...
@classmethod
def __class_getitem__(cls, args: Any) -> type[NdarrayType]: ...
def __getitem__(self, i: Any) -> Any: ...
def __setitem__(self, i: Any, v: Any) -> None: ...
def __repr__(self) -> str: ...
def __str__(self) -> str: ...

ndarray = NdarrayType
NDArray = NdarrayType
70 changes: 70 additions & 0 deletions stubs/quadrants/types/primitive_types.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Any, ClassVar, Union

from quadrants._lib.core.quadrants_python import DataTypeCxx

class PrimitiveMeta(type):
cxx: DataTypeCxx
def __call__(cls, *args: Any, **kwargs: Any) -> Any: ...
def __eq__(cls, other: object) -> bool: ...
def __ne__(cls, other: object) -> bool: ...
def __hash__(cls) -> int: ...
def __repr__(cls) -> str: ...
def __getattr__(cls, name: str) -> Any: ...

class PrimitiveBase(metaclass=PrimitiveMeta):
cxx: ClassVar[DataTypeCxx]

class f16(PrimitiveBase): ...
class f32(PrimitiveBase): ...
class f64(PrimitiveBase): ...
class i8(PrimitiveBase): ...
class i16(PrimitiveBase): ...
class i32(PrimitiveBase): ...
class i64(PrimitiveBase): ...
class u1(PrimitiveBase): ...
class u8(PrimitiveBase): ...
class u16(PrimitiveBase): ...
class u32(PrimitiveBase): ...
class u64(PrimitiveBase): ...

float16 = f16
float32 = f32
float64 = f64
int8 = i8
int16 = i16
int32 = i32
int64 = i64
uint1 = u1
uint8 = u8
uint16 = u16
uint32 = u32
uint64 = u64

# Raw C++ DataType instances (internal use)
f16_cxx: DataTypeCxx
f32_cxx: DataTypeCxx
f64_cxx: DataTypeCxx
i8_cxx: DataTypeCxx
i16_cxx: DataTypeCxx
i32_cxx: DataTypeCxx
i64_cxx: DataTypeCxx
u1_cxx: DataTypeCxx
u8_cxx: DataTypeCxx
u16_cxx: DataTypeCxx
u32_cxx: DataTypeCxx
u64_cxx: DataTypeCxx

class RefType:
tp: Any
def __init__(self, tp: Any) -> None: ...

def ref(tp: Any) -> RefType: ...

real_types: set[type[PrimitiveBase] | type]
real_type_ids: set[int]
integer_types: set[type[PrimitiveBase] | type]
integer_type_ids: set[int]
all_types: set[type[PrimitiveBase] | type]
cxx_type_ids: set[int]
type_ids: set[int]
_python_primitive_types = Union[int, float, bool, str, None]
6 changes: 6 additions & 0 deletions stubs/quadrants/types/utils.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Any

def is_signed(dt: Any) -> bool: ...
def is_integral(dt: Any) -> bool: ...
def is_real(dt: Any) -> bool: ...
def is_tensor(dt: Any) -> bool: ...
21 changes: 16 additions & 5 deletions tests/python/pyright/test_ndarray_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This is a test file. It just has to exist, to check that pyright works with it.
# Pyright test: NDArray annotations accepted by the type checker and functional at runtime.

import quadrants as qd

Expand All @@ -7,29 +7,40 @@
qd.init(arch=qd.cpu)


# Legacy call syntax (still works, pyright warns about call expressions in type positions)
@qd.kernel
def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k1(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]


@qd.kernel()
def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k2(a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]


# New subscript syntax (preferred, no pyright warnings)
@qd.kernel
def k3(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ...


@qd.data_oriented
class SomeClass:
@qd.kernel
def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k1(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]

@qd.kernel()
def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ...
def k2(self, a: qd.types.ndarray(), b: qd.types.NDArray, c: qd.types.NDArray[qd.i32, 1]) -> None: ... # type: ignore[reportInvalidTypeForm]

@qd.kernel
def k3(self, a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32], c: qd.types.NDArray) -> None: ...


@test_utils.test()
def test_ndarray_type():
a = qd.ndarray(qd.i32, (10,))
k1(a, a, a)
k2(a, a, a)
k3(a, a, a)

some_class = SomeClass()
some_class.k1(a, a, a)
some_class.k2(a, a, a)
some_class.k3(a, a, a)
73 changes: 73 additions & 0 deletions tests/python/pyright/test_primitive_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Pyright test: primitive dtype classes and NDArray subscript syntax.

This file must produce zero pyright errors. It validates that:
- Primitive dtypes (f32, i32, etc.) work as type annotations
- NDArray[dtype, ndim] subscript syntax is accepted
- from __future__ import annotations (stringified) works
- NDArray works via qd.types and top-level qd.NDArray
- Return types and Optional wrappers work
"""

from __future__ import annotations

from typing import Optional

import quadrants as qd


# Primitive types as annotations
def accept_f32(x: qd.f32) -> None: ...


def accept_i32(x: qd.i32) -> None: ...


def accept_any_dtype(x: qd.f32 | qd.i32 | qd.u8) -> None: ...


# NDArray subscript: dtype + ndim
def kernel_2d(a: qd.types.NDArray[qd.f32, 2]) -> None: ...


# NDArray subscript: dtype only
def kernel_dtype(a: qd.types.NDArray[qd.i32]) -> None: ...


# NDArray bare (no subscript)
def kernel_bare(a: qd.types.NDArray) -> None: ...


# Multiple NDArray args with different types
def multi_args(
a: qd.types.NDArray[qd.f32, 2],
b: qd.types.NDArray[qd.i32, 1],
c: qd.types.NDArray,
) -> None: ...


# Top-level NDArray alias (accessible via qd.types)
def top_level(a: qd.types.NDArray[qd.f32, 2]) -> None: ...


# Return types
def make_arr() -> qd.types.NDArray[qd.f32, 2]: ...


# Optional wrapping
def maybe_arr(x: Optional[qd.types.NDArray[qd.f32, 2]]) -> None: ...


# Variable annotations
field1: qd.types.NDArray[qd.f32, 2]
field2: qd.types.NDArray


# In class body
class MyModel:
buf: qd.types.NDArray[qd.f32, 3]

def forward(self, x: qd.types.NDArray[qd.f32, 2]) -> qd.types.NDArray[qd.f32, 2]: ...


# Access via qd.types submodule
def via_types(x: qd.types.NDArray[qd.types.f32, 2]) -> None: ...
6 changes: 4 additions & 2 deletions tests/python/test_shared_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def test_shared_array_not_accumulated_across_offloads(num_dim, first_shape_delta
max_shared_bytes = qd.lang.impl.get_max_shared_memory_bytes(is_lowerbound_ok=True)
# 75% of max shared memory in bytes, converted to element counts
shared_array_bytes = int(0.75 * max_shared_bytes)
num_elems_1 = shared_array_bytes // qd._lib.core.data_type_size(dtype1) + first_shape_delta_size
num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(dtype2)
num_elems_1 = (
shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype1)) + first_shape_delta_size
)
num_elems_2 = shared_array_bytes // qd._lib.core.data_type_size(qd.lang.util.cook_dtype(dtype2))

# Build 1D or 2D shape tuples with the same total number of elements.
# For 2D, split into (block_dim, num_elems // block_dim).
Expand Down
Loading