Skip to content

Commit c306153

Browse files
committed
feat: improve support for structured dtype arrays held by StridedMemoryView
1 parent 2d408c1 commit c306153

1 file changed

Lines changed: 20 additions & 5 deletions

File tree

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,7 @@ cdef class StridedMemoryView:
365365
if self.dl_tensor != NULL:
366366
self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
367367
elif self.metadata is not None:
368-
# TODO: this only works for built-in numeric types
369-
self._dtype = _typestr2dtype[self.metadata["typestr"]]
368+
self._dtype = _typestr2dtype(self.metadata["typestr"])
370369
return self._dtype
371370

372371

@@ -503,8 +502,24 @@ _builtin_numeric_dtypes = [
503502
numpy.dtype("bool"),
504503
]
505504
# Doing it once to avoid repeated overhead
506-
_typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes}
507-
_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes}
505+
_TYPESTR_TO_DTYPE = {dtype.str: dtype for dtype in _builtin_numeric_dtypes}
506+
_TYPESTR_TO_ITEMSIZE = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes}
507+
508+
cpdef object _typestr2dtype(str typestr):
509+
if (dtype := _TYPESTR_TO_DTYPE.get(typestr)) is not None:
510+
return dtype
511+
512+
_TYPESTR_TO_DTYPE[typestr] = dtype = numpy.dtype(typestr)
513+
return dtype
514+
515+
516+
cdef int _typestr2itemsize(str typestr) except -1:
517+
if (itemsize := _TYPESTR_TO_ITEMSIZE.get(typestr)) is not None:
518+
return itemsize
519+
520+
dtype = _typestr2dtype(typestr)
521+
_TYPESTR_TO_ITEMSIZE[typestr] = itemsize = dtype.itemsize
522+
return itemsize
508523

509524

510525
cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
@@ -664,7 +679,7 @@ cdef _StridedLayout layout_from_cai(object metadata):
664679
cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout)
665680
cdef object shape = metadata["shape"]
666681
cdef object strides = metadata.get("strides")
667-
cdef int itemsize = _typestr2itemsize[metadata["typestr"]]
682+
cdef int itemsize = _typestr2itemsize(metadata["typestr"])
668683
layout.init_from_tuple(shape, strides, itemsize, True)
669684
return layout
670685

0 commit comments

Comments
 (0)