feat[next]: add support for array_namespace allocation#2442
feat[next]: add support for array_namespace allocation#2442havogt wants to merge 57 commits intoGridTools:mainfrom
Conversation
egparedes
left a comment
There was a problem hiding this comment.
I have some questions and ideas to discuss....
tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Pull request overview
This PR adds support for array namespace allocation in GT4Py's field construction API. The main changes enable users to directly pass array namespaces like numpy or cupy as allocators when creating fields, in addition to the existing GT4Py-specific field buffer allocators.
Changes:
- Added support for array namespace allocators (e.g., numpy, cupy) in field construction
- Implemented aligned_index handling with absolute-to-relative index conversion for non-zero origin domains
- Refactored type narrowing from TypeGuard to TypeIs for better type safety
- Extracted scalar type definitions to a separate
gt4py._core.typesmodule to avoid naming conflicts
Reviewed changes
Copilot reviewed 27 out of 27 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
tests/next_tests/unit_tests/test_custom_layout_allocators.py |
New comprehensive test file for custom layout allocators and aligned_index handling |
tests/next_tests/unit_tests/test_constructors.py |
Expanded tests to cover array namespace allocators and GPU backends |
tests/next_tests/unit_tests/test_allocators.py |
Removed (tests moved to test_custom_layout_allocators.py) |
tests/core_tests/unit_tests/test_nd_array_utils.py |
New tests for device translation utilities |
src/gt4py/_core/types.py |
New module extracting scalar type definitions from definitions.py |
src/gt4py/_core/definitions.py |
Updated to use types from core_types module |
src/gt4py/_core/ndarray_utils.py |
Added ArrayNamespace protocol and device translation infrastructure |
src/gt4py/storage/allocators.py |
Made ArrayUtils immutable with frozen=True |
src/gt4py/next/custom_layout_allocators.py |
Migrated from TypeGuard to TypeIs; added aligned_index conversion; removed allocate() function |
src/gt4py/next/constructors.py |
Major refactoring to support array namespace allocators via FieldConstructor class |
src/gt4py/next/__init__.py |
Exported FieldConstructor in public API |
src/gt4py/next/typing.py |
Replaced FieldBufferAllocationUtil with Allocator type alias |
| Multiple test/src files | Updated imports from allocators to custom_layout_allocators |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
egparedes
left a comment
There was a problem hiding this comment.
Overall looks good. I just have some comments and questions about style and naming
src/gt4py/_core/ndarray_utils.py
Outdated
| bool: type | ||
| int8: type | ||
| int16: type | ||
| int32: type | ||
| int64: type | ||
| uint8: type | ||
| uint16: type | ||
| uint32: type | ||
| uint64: type | ||
| float32: type | ||
| float64: type |
There was a problem hiding this comment.
Question: are we sure we need these to be types? I mean, if the Array API doesn't required them to be types, it might be enough to require that they are callables, right?
There was a problem hiding this comment.
Why should they be callable actually? I think the only requirement is that they implement __eq__(), see https://data-apis.org/array-api/latest/API_specification/data_types.html
There was a problem hiding this comment.
I introduced _EqualityComparable
|
|
||
| def is_array_namespace(obj: Any) -> TypeGuard[ArrayNamespace]: | ||
| """ | ||
| Check whether `obj` (structurally) is a namespace of the array API. | ||
|
|
||
| See description in 'ArrayNamespace'. | ||
| """ | ||
|
|
||
| return ( | ||
| hasattr(obj, "empty") | ||
| and hasattr(obj, "zeros") | ||
| and hasattr(obj, "ones") | ||
| and hasattr(obj, "full") | ||
| and hasattr(obj, "asarray") | ||
| and hasattr(obj, "bool") | ||
| and hasattr(obj, "int8") | ||
| and hasattr(obj, "int16") | ||
| and hasattr(obj, "int32") | ||
| and hasattr(obj, "int64") | ||
| and hasattr(obj, "uint8") | ||
| and hasattr(obj, "uint16") | ||
| and hasattr(obj, "uint32") | ||
| and hasattr(obj, "uint64") | ||
| and hasattr(obj, "float32") | ||
| and hasattr(obj, "float64") | ||
| ) | ||
|
|
There was a problem hiding this comment.
I'm not sure if it works, but I'd rather have a more automated implementation to avoid that the protocol and the typeguard get out of sync.
_array_namespace_members: Final = tuple(vars(ArrayNamespace).keys())
# and then:
def is_array_namespace(obj: Any) -> TypeGuard[ArrayNamespace]:
return all(hasattr(name) for name in _array_namespace_members) Co-authored-by: Enrique González Paredes <enriqueg@cscs.ch>
Co-authored-by: Enrique González Paredes <enriqueg@cscs.ch>
Refactors the constructor functions to allow construction from array namespaces. This will allow to use the allocators for jax fields in a next PR.
We introduce a new concept
Allocatorwhich is the only public type for the different kind of allocators (array namespace or custom layout allocators).In addition to the
emtpy, etc. functions which are used asgtx.empty(..., allocator=..., device=..., alignment=...)by the user, we offer aFieldConstructor-class which binds the low-level details of allocation (allocator, device, alignment).