Skip to content

Comments

feat[next]: add support for array_namespace allocation#2442

Open
havogt wants to merge 57 commits intoGridTools:mainfrom
havogt:jnp_support
Open

feat[next]: add support for array_namespace allocation#2442
havogt wants to merge 57 commits intoGridTools:mainfrom
havogt:jnp_support

Conversation

@havogt
Copy link
Contributor

@havogt havogt commented Jan 16, 2026

Refactors the constructor functions to allow construction from array namespaces. This will allow to use the allocators for jax fields in a next PR.

We introduce a new concept Allocator which is the only public type for the different kind of allocators (array namespace or custom layout allocators).

In addition to the emtpy, etc. functions which are used as gtx.empty(..., allocator=..., device=..., alignment=...) by the user, we offer a FieldConstructor-class which binds the low-level details of allocation (allocator, device, alignment).

Copy link
Contributor

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some questions and ideas to discuss....

@havogt havogt requested a review from egparedes February 20, 2026 12:34
@havogt havogt changed the title feat[next]: add support for array_namespace allocation, enable jax.numpy testing feat[next]: add support for array_namespace allocation Feb 20, 2026
@havogt havogt requested a review from Copilot February 20, 2026 13:17
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds support for array namespace allocation in GT4Py's field construction API. The main changes enable users to directly pass array namespaces like numpy or cupy as allocators when creating fields, in addition to the existing GT4Py-specific field buffer allocators.

Changes:

  • Added support for array namespace allocators (e.g., numpy, cupy) in field construction
  • Implemented aligned_index handling with absolute-to-relative index conversion for non-zero origin domains
  • Refactored type narrowing from TypeGuard to TypeIs for better type safety
  • Extracted scalar type definitions to a separate gt4py._core.types module to avoid naming conflicts

Reviewed changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/next_tests/unit_tests/test_custom_layout_allocators.py New comprehensive test file for custom layout allocators and aligned_index handling
tests/next_tests/unit_tests/test_constructors.py Expanded tests to cover array namespace allocators and GPU backends
tests/next_tests/unit_tests/test_allocators.py Removed (tests moved to test_custom_layout_allocators.py)
tests/core_tests/unit_tests/test_nd_array_utils.py New tests for device translation utilities
src/gt4py/_core/types.py New module extracting scalar type definitions from definitions.py
src/gt4py/_core/definitions.py Updated to use types from core_types module
src/gt4py/_core/ndarray_utils.py Added ArrayNamespace protocol and device translation infrastructure
src/gt4py/storage/allocators.py Made ArrayUtils immutable with frozen=True
src/gt4py/next/custom_layout_allocators.py Migrated from TypeGuard to TypeIs; added aligned_index conversion; removed allocate() function
src/gt4py/next/constructors.py Major refactoring to support array namespace allocators via FieldConstructor class
src/gt4py/next/__init__.py Exported FieldConstructor in public API
src/gt4py/next/typing.py Replaced FieldBufferAllocationUtil with Allocator type alias
Multiple test/src files Updated imports from allocators to custom_layout_allocators

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Contributor

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good. I just have some comments and questions about style and naming

Comment on lines 64 to 74
bool: type
int8: type
int16: type
int32: type
int64: type
uint8: type
uint16: type
uint32: type
uint64: type
float32: type
float64: type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: are we sure we need these to be types? I mean, if the Array API doesn't required them to be types, it might be enough to require that they are callables, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should they be callable actually? I think the only requirement is that they implement __eq__(), see https://data-apis.org/array-api/latest/API_specification/data_types.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I introduced _EqualityComparable

Comment on lines 76 to 102

def is_array_namespace(obj: Any) -> TypeGuard[ArrayNamespace]:
"""
Check whether `obj` (structurally) is a namespace of the array API.

See description in 'ArrayNamespace'.
"""

return (
hasattr(obj, "empty")
and hasattr(obj, "zeros")
and hasattr(obj, "ones")
and hasattr(obj, "full")
and hasattr(obj, "asarray")
and hasattr(obj, "bool")
and hasattr(obj, "int8")
and hasattr(obj, "int16")
and hasattr(obj, "int32")
and hasattr(obj, "int64")
and hasattr(obj, "uint8")
and hasattr(obj, "uint16")
and hasattr(obj, "uint32")
and hasattr(obj, "uint64")
and hasattr(obj, "float32")
and hasattr(obj, "float64")
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it works, but I'd rather have a more automated implementation to avoid that the protocol and the typeguard get out of sync.

_array_namespace_members: Final = tuple(vars(ArrayNamespace).keys())

# and then:
def is_array_namespace(obj: Any) -> TypeGuard[ArrayNamespace]:
    return all(hasattr(name) for name in _array_namespace_members) 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants