77from __future__ import annotations
88
99from dataclasses import dataclass
10- from typing import Any
10+ from typing import Any , Literal
1111
1212import numpy as np
13+ from numpy .typing import DTypeLike
1314
15+ import pyopencl as cl
1416import pyopencl .array as cla
1517from pytools import memoize
1618from pytools .tag import Tag , Taggable , ToTagSetConvertible
@@ -74,6 +76,9 @@ class TaggableCLArray(cla.Array, Taggable):
7476 record application-specific metadata to drive the optimizations in
7577 :meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
7678 """
79+ tags : frozenset [Tag ]
80+ axes : tuple [Axis , ...]
81+
7782 def __init__ (self , cq , shape , dtype , order = "C" , allocator = None ,
7883 data = None , offset = 0 , strides = None , events = None , _flags = None ,
7984 _fast = False , _size = None , _context = None , _queue = None ,
@@ -165,13 +170,20 @@ def to_tagged_cl_array(ary: cla.Array,
165170# }}}
166171
167172
173+ _EMPTY_TAG_SET : frozenset [Tag ] = frozenset ()
174+
175+
168176# {{{ creation
169177
170- def empty (queue , shape , dtype = float , * ,
171- axes : tuple [Axis , ...] | None = None ,
172- tags : frozenset [Tag ] = frozenset (),
173- order : str = "C" ,
174- allocator = None ) -> TaggableCLArray :
178+ def empty (
179+ queue : cl .CommandQueue ,
180+ shape : tuple [int , ...] | int ,
181+ dtype : DTypeLike = float ,
182+ * , axes : tuple [Axis , ...] | None = None ,
183+ tags : frozenset [Tag ] = _EMPTY_TAG_SET ,
184+ order : Literal ["C" ] | Literal ["F" ] = "C" ,
185+ allocator : cla .Allocator | None = None ,
186+ ) -> TaggableCLArray :
175187 if dtype is not None :
176188 dtype = np .dtype (dtype )
177189
@@ -181,11 +193,15 @@ def empty(queue, shape, dtype=float, *,
181193 order = order , allocator = allocator )
182194
183195
184- def zeros (queue , shape , dtype = float , * ,
185- axes : tuple [Axis , ...] | None = None ,
186- tags : frozenset [Tag ] = frozenset (),
187- order : str = "C" ,
188- allocator = None ) -> TaggableCLArray :
196+ def zeros (
197+ queue : cl .CommandQueue ,
198+ shape : tuple [int , ...] | int ,
199+ dtype : DTypeLike = float ,
200+ * , axes : tuple [Axis , ...] | None = None ,
201+ tags : frozenset [Tag ] = _EMPTY_TAG_SET ,
202+ order : Literal ["C" ] | Literal ["F" ] = "C" ,
203+ allocator : cla .Allocator | None = None ,
204+ ) -> TaggableCLArray :
189205 result = empty (
190206 queue , shape , dtype = dtype , axes = axes , tags = tags ,
191207 order = order , allocator = allocator )
@@ -194,10 +210,13 @@ def zeros(queue, shape, dtype=float, *,
194210 return result
195211
196212
197- def to_device (queue , ary , * ,
198- axes : tuple [Axis , ...] | None = None ,
199- tags : frozenset [Tag ] = frozenset (),
200- allocator = None ):
213+ def to_device (
214+ queue : cl .CommandQueue ,
215+ ary : np .ndarray [Any ],
216+ * , axes : tuple [Axis , ...] | None = None ,
217+ tags : frozenset [Tag ] = _EMPTY_TAG_SET ,
218+ allocator : cla .Allocator | None = None ,
219+ ) -> TaggableCLArray :
201220 return to_tagged_cl_array (
202221 cla .to_device (queue , ary , allocator = allocator ),
203222 axes = axes , tags = tags )
0 commit comments