Skip to content

Commit 980cbb4

Browse files
author
peng.li24
committed
refactor(numpycpp): update elementwise_py.h
1 parent 20a6878 commit 980cbb4

1 file changed

Lines changed: 34 additions & 12 deletions

File tree

numpycpp/elementwise_py.h

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,21 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
269269
auto dt = arr.dtype();
270270

271271
// py::dtype::of<float>() / py::dtype::of<double>() 在 Python 传入的
272-
// numpy 数组上可能不匹配(已知 pybind11 问题)。用 dtype.kind() + itemsize 回退。
273-
char _kind = dt.kind();
274-
bool _is_f32 = (_kind == 'f' && buf.itemsize == 4);
275-
bool _is_f64 = (_kind == 'f' && buf.itemsize == 8);
276-
bool _is_i32 = (_kind == 'i' && buf.itemsize == 4);
277-
bool _is_i64 = (_kind == 'i' && buf.itemsize == 8);
278-
bool _is_bool = (_kind == 'b');
272+
// numpy 数组上可能不匹配(已知 pybind11 问题)。用 buffer format + itemsize
273+
// 回退——buf.format 来自 C API 的 PyObject_GetBuffer,不触发 Python 属性调用,
274+
// 避免 astype 内递归。
275+
char _fmt_char = buf.format.empty() ? '\0' :
276+
(buf.format[0] == '<' || buf.format[0] == '>' || buf.format[0] == '=')
277+
? buf.format[1] : buf.format[0];
278+
bool _is_f32 = (_fmt_char == 'f' && buf.itemsize == 4);
279+
bool _is_f64 = (_fmt_char == 'd' && buf.itemsize == 8) ||
280+
(_fmt_char == 'f' && buf.itemsize == 8);
281+
bool _is_i32 = (_fmt_char == 'i' && buf.itemsize == 4) ||
282+
(_fmt_char == 'l' && buf.itemsize == 4);
283+
bool _is_i64 = (_fmt_char == 'i' && buf.itemsize == 8) ||
284+
(_fmt_char == 'l' && buf.itemsize == 8) ||
285+
(_fmt_char == 'q' && buf.itemsize == 8);
286+
bool _is_bool = (_fmt_char == '?' || _fmt_char == 'b');
279287

280288
#define _ASTYPE_MATCH(SrcT) \
281289
(dt.is(py::dtype::of<SrcT>()) || \
@@ -293,19 +301,33 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
293301
return r; \
294302
}
295303

304+
// ═══════════════════════════════════════════════════════════════════════
305+
// 目标字符串语义(严格对齐 numpy ndarray.astype() 行为):
306+
//
307+
// "float64" / "double" → C++ double (64-bit)
308+
// "float32" → C++ float (32-bit)
309+
// "float" → C++ double (64-bit) ← 注意!numpy 中
310+
// np.float64(1).astype(float) → float64
311+
// np.float32(1).astype(float) → float64
312+
// np.int32(1).astype(float) → float64
313+
// 即 numpy 默认 float = float64,不是 float32
314+
// "int" / "int32" → C++ int (32-bit)
315+
// "int64" → C++ int64_t
316+
// "bool" → C++ bool
317+
// ═══════════════════════════════════════════════════════════════════════
296318
// float64
297319
_ASTYPE_CASE(double, "float64", double)
298320
_ASTYPE_CASE(double, "double", double)
299321
_ASTYPE_CASE(double, "float32", float)
300-
_ASTYPE_CASE(double, "float", float)
322+
_ASTYPE_CASE(double, "float", double) // numpy: float64.astype(float) → float64
301323
_ASTYPE_CASE(double, "int", int)
302324
_ASTYPE_CASE(double, "int32", int)
303325
_ASTYPE_CASE(double, "int64", int64_t)
304326
_ASTYPE_CASE(double, "bool", bool)
305327
// float32
306328
_ASTYPE_CASE(float, "float64", double)
307329
_ASTYPE_CASE(float, "double", double)
308-
_ASTYPE_CASE(float, "float", double)
330+
_ASTYPE_CASE(float, "float", double) // "float" → float64 (numpy 默认)
309331
_ASTYPE_CASE(float, "float32", float)
310332
_ASTYPE_CASE(float, "int", int)
311333
_ASTYPE_CASE(float, "int32", int)
@@ -317,15 +339,15 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
317339
_ASTYPE_CASE(int, "float64", double)
318340
_ASTYPE_CASE(int, "double", double)
319341
_ASTYPE_CASE(int, "float32", float)
320-
_ASTYPE_CASE(int, "float", float)
342+
_ASTYPE_CASE(int, "float", double) // "float" → float64 (numpy 默认)
321343
_ASTYPE_CASE(int, "int64", int64_t)
322344
_ASTYPE_CASE(int, "bool", bool)
323345
// int64
324346
_ASTYPE_CASE(int64_t, "int64", int64_t)
325347
_ASTYPE_CASE(int64_t, "float64", double)
326348
_ASTYPE_CASE(int64_t, "double", double)
327349
_ASTYPE_CASE(int64_t, "float32", float)
328-
_ASTYPE_CASE(int64_t, "float", float)
350+
_ASTYPE_CASE(int64_t, "float", double) // "float" → float64 (numpy 默认)
329351
_ASTYPE_CASE(int64_t, "int", int)
330352
_ASTYPE_CASE(int64_t, "int32", int)
331353
_ASTYPE_CASE(int64_t, "bool", bool)
@@ -334,7 +356,7 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
334356
_ASTYPE_CASE(bool, "float64", double)
335357
_ASTYPE_CASE(bool, "double", double)
336358
_ASTYPE_CASE(bool, "float32", float)
337-
_ASTYPE_CASE(bool, "float", float)
359+
_ASTYPE_CASE(bool, "float", double) // "float" → float64 (numpy 默认)
338360
_ASTYPE_CASE(bool, "int", int)
339361
_ASTYPE_CASE(bool, "int32", int)
340362
_ASTYPE_CASE(bool, "int64", int64_t)

0 commit comments

Comments
 (0)