Skip to content

Commit 6b0dffe

Browse files
leofang and claude committed
Address review comments: dtypes, stale cache, stream_ptr, sync notes
- Add uint16/uint32/uint64 to AOTI dtype and itemsize maps (fixes regression where these torch dtypes would raise TypeError instead of being handled by the bridge)
- Clear buf._dtype when repopulating a reused StridedMemoryView to prevent returning a stale cached dtype
- Reject stream_ptr=None for CUDA tensors with BufferError (matches DLPack semantics where None is ambiguous)
- Add "keep in sync" comments to aoti_shim.h and aoti_shim.def per rwgk's review suggestion

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1ac154a commit 6b0dffe

3 files changed

Lines changed: 39 additions & 4 deletions

File tree

cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
; Used on Windows only: 'lib /DEF:aoti_shim.def /OUT:aoti_shim.lib /MACHINE:X64'
33
; generates a minimal import library that satisfies the MSVC linker.
44
; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch').
5+
;
6+
; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations
7+
; in aoti_shim.h. build_hooks.py turns this file into the stub import library
8+
; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI
9+
; symbol must be updated in both files.
510
LIBRARY torch_cpu.dll
611
EXPORTS
712
aoti_torch_get_data_ptr
@@ -14,6 +19,9 @@ EXPORTS
1419
aoti_torch_dtype_float64
1520
aoti_torch_dtype_bfloat16
1621
aoti_torch_dtype_uint8
22+
aoti_torch_dtype_uint16
23+
aoti_torch_dtype_uint32
24+
aoti_torch_dtype_uint64
1725
aoti_torch_dtype_int8
1826
aoti_torch_dtype_int16
1927
aoti_torch_dtype_int32

cuda_core/cuda/core/_include/aoti_shim.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ typedef int32_t AOTITorchError;
5050
struct AtenTensorOpaque;
5151
typedef struct AtenTensorOpaque* AtenTensorHandle;
5252

53+
/*
54+
* IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
55+
* aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
56+
* stub import library that MSVC needs to link _tensor_bridge without making
57+
* PyTorch a build-time dependency. If you add, remove, or rename an imported
58+
* AOTI symbol here, update aoti_shim.def in the same change.
59+
*/
60+
5361
/* ---- tensor metadata --------------------------------------------------- */
5462

5563
AOTI_SHIM_API AOTITorchError aoti_torch_get_data_ptr(
@@ -74,6 +82,9 @@ AOTI_SHIM_API int32_t aoti_torch_dtype_float32(void);
7482
AOTI_SHIM_API int32_t aoti_torch_dtype_float64(void);
7583
AOTI_SHIM_API int32_t aoti_torch_dtype_bfloat16(void);
7684
AOTI_SHIM_API int32_t aoti_torch_dtype_uint8(void);
85+
AOTI_SHIM_API int32_t aoti_torch_dtype_uint16(void);
86+
AOTI_SHIM_API int32_t aoti_torch_dtype_uint32(void);
87+
AOTI_SHIM_API int32_t aoti_torch_dtype_uint64(void);
7788
AOTI_SHIM_API int32_t aoti_torch_dtype_int8(void);
7889
AOTI_SHIM_API int32_t aoti_torch_dtype_int16(void);
7990
AOTI_SHIM_API int32_t aoti_torch_dtype_int32(void);

cuda_core/cuda/core/_tensor_bridge.pyx

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ cdef extern from "_include/aoti_shim.h":
8585
int32_t aoti_torch_dtype_float64()
8686
int32_t aoti_torch_dtype_bfloat16()
8787
int32_t aoti_torch_dtype_uint8()
88+
int32_t aoti_torch_dtype_uint16()
89+
int32_t aoti_torch_dtype_uint32()
90+
int32_t aoti_torch_dtype_uint64()
8891
int32_t aoti_torch_dtype_int8()
8992
int32_t aoti_torch_dtype_int16()
9093
int32_t aoti_torch_dtype_int32()
@@ -196,6 +199,9 @@ cdef dict _build_dtype_map():
196199
aoti_torch_dtype_float32(): numpy.dtype(numpy.float32),
197200
aoti_torch_dtype_float64(): numpy.dtype(numpy.float64),
198201
aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8),
202+
aoti_torch_dtype_uint16(): numpy.dtype(numpy.uint16),
203+
aoti_torch_dtype_uint32(): numpy.dtype(numpy.uint32),
204+
aoti_torch_dtype_uint64(): numpy.dtype(numpy.uint64),
199205
aoti_torch_dtype_int8(): numpy.dtype(numpy.int8),
200206
aoti_torch_dtype_int16(): numpy.dtype(numpy.int16),
201207
aoti_torch_dtype_int32(): numpy.dtype(numpy.int32),
@@ -228,6 +234,9 @@ cdef dict _build_itemsize_map():
228234
return {
229235
aoti_torch_dtype_bool(): sizeof(uint8_t),
230236
aoti_torch_dtype_uint8(): sizeof(uint8_t),
237+
aoti_torch_dtype_uint16(): sizeof(int16_t),
238+
aoti_torch_dtype_uint32(): sizeof(int32_t),
239+
aoti_torch_dtype_uint64(): sizeof(int64_t),
231240
aoti_torch_dtype_int8(): sizeof(int8_t),
232241
aoti_torch_dtype_float16(): sizeof(int16_t), # no C float16
233242
aoti_torch_dtype_bfloat16(): sizeof(int16_t), # no C bfloat16
@@ -344,6 +353,7 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None):
344353
buf = StridedMemoryView.__new__(StridedMemoryView)
345354

346355
buf.ptr = <intptr_t>data_ptr
356+
buf._dtype = None # clear cached dtype (view may be reused)
347357
# PyTorch always reports tensors as writable via both DLPack
348358
# (flags=0, no DLPACK_FLAG_BITMASK_READ_ONLY) and CAI
349359
# (__cuda_array_interface__["data"] = (ptr, False)). Tensors that
@@ -364,10 +374,16 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None):
364374
buf.is_device_accessible = True
365375

366376
# -- stream ordering (matches the DLPack contract) --
367-
if stream_ptr is not None:
368-
_stream_ptr_int = int(stream_ptr)
369-
if _stream_ptr_int != -1:
370-
sync_torch_stream(device_index, _stream_ptr_int)
377+
# stream_ptr=None is ambiguous for CUDA tensors — the caller must
378+
# explicitly choose -1 (no sync) or a valid stream pointer.
379+
if stream_ptr is None:
380+
raise BufferError(
381+
"stream_ptr=None is ambiguous for CUDA tensors; "
382+
"pass stream_ptr=-1 to opt out of synchronization, "
383+
"or pass a valid stream pointer")
384+
_stream_ptr_int = int(stream_ptr)
385+
if _stream_ptr_int != -1:
386+
sync_torch_stream(device_index, _stream_ptr_int)
371387
else:
372388
raise BufferError(
373389
f"Unsupported device type from torch tensor "

0 commit comments

Comments
 (0)