Skip to content

Commit 9717973

Browse files
author
peng.li24
committed
refactor(numpycpp): update elementwise_py.h
1 parent f92562e commit 9717973

1 file changed

Lines changed: 27 additions & 8 deletions

File tree

numpycpp/elementwise_py.h

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <algorithm>
1919
#include <stdexcept>
2020
#include <cstdint>
21+
#include <type_traits>
2122

2223
namespace py = pybind11;
2324

@@ -267,16 +268,33 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
267268
auto buf = arr.request();
268269
auto dt = arr.dtype();
269270

271+
// py::dtype::of<float>() / py::dtype::of<double>() 在 Python 传入的
272+
// numpy 数组上可能不匹配(已知 pybind11 问题)。用 dtype.kind() + itemsize 回退。
273+
char _kind = dt.kind();
274+
bool _is_f32 = (_kind == 'f' && buf.itemsize == 4);
275+
bool _is_f64 = (_kind == 'f' && buf.itemsize == 8);
276+
bool _is_i32 = (_kind == 'i' && buf.itemsize == 4);
277+
bool _is_i64 = (_kind == 'i' && buf.itemsize == 8);
278+
bool _is_bool = (_kind == 'b');
279+
280+
#define _ASTYPE_MATCH(SrcT) \
281+
(dt.is(py::dtype::of<SrcT>()) || \
282+
(std::is_same<SrcT, float>::value && _is_f32) || \
283+
(std::is_same<SrcT, double>::value && _is_f64) || \
284+
(std::is_same<SrcT, int>::value && _is_i32) || \
285+
(std::is_same<SrcT, int64_t>::value&& _is_i64) || \
286+
(std::is_same<SrcT, bool>::value && _is_bool))
287+
270288
#define _ASTYPE_CASE(SrcT, dst_str, DstT) \
271-
if (dt.is(py::dtype::of<SrcT>()) && (dtype == dst_str)) { \
289+
if (_ASTYPE_MATCH(SrcT) && (dtype == dst_str)) { \
272290
py::array_t<DstT> r(buf.shape); \
273291
numpy::astype<DstT, SrcT>(static_cast<const SrcT*>(buf.ptr), \
274292
static_cast<DstT*>(r.request().ptr), buf.size); \
275293
return r; \
276294
}
277295

278296
// float64
279-
_ASTYPE_CASE(double, "float64", double) // 自转换
297+
_ASTYPE_CASE(double, "float64", double)
280298
_ASTYPE_CASE(double, "double", double)
281299
_ASTYPE_CASE(double, "float32", float)
282300
_ASTYPE_CASE(double, "float", float)
@@ -287,14 +305,14 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
287305
// float32
288306
_ASTYPE_CASE(float, "float64", double)
289307
_ASTYPE_CASE(float, "double", double)
290-
_ASTYPE_CASE(float, "float", double) // numpy 约定: np.float32(1).astype(float) → float64
291-
_ASTYPE_CASE(float, "float32", float) // 自转换: 无操作
308+
_ASTYPE_CASE(float, "float", double)
309+
_ASTYPE_CASE(float, "float32", float)
292310
_ASTYPE_CASE(float, "int", int)
293311
_ASTYPE_CASE(float, "int32", int)
294312
_ASTYPE_CASE(float, "int64", int64_t)
295313
_ASTYPE_CASE(float, "bool", bool)
296314
// int32
297-
_ASTYPE_CASE(int, "int32", int) // 自转换
315+
_ASTYPE_CASE(int, "int32", int)
298316
_ASTYPE_CASE(int, "int", int)
299317
_ASTYPE_CASE(int, "float64", double)
300318
_ASTYPE_CASE(int, "double", double)
@@ -303,7 +321,7 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
303321
_ASTYPE_CASE(int, "int64", int64_t)
304322
_ASTYPE_CASE(int, "bool", bool)
305323
// int64
306-
_ASTYPE_CASE(int64_t, "int64", int64_t) // 自转换
324+
_ASTYPE_CASE(int64_t, "int64", int64_t)
307325
_ASTYPE_CASE(int64_t, "float64", double)
308326
_ASTYPE_CASE(int64_t, "double", double)
309327
_ASTYPE_CASE(int64_t, "float32", float)
@@ -312,7 +330,7 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
312330
_ASTYPE_CASE(int64_t, "int32", int)
313331
_ASTYPE_CASE(int64_t, "bool", bool)
314332
// bool
315-
_ASTYPE_CASE(bool, "bool", bool) // 自转换
333+
_ASTYPE_CASE(bool, "bool", bool)
316334
_ASTYPE_CASE(bool, "float64", double)
317335
_ASTYPE_CASE(bool, "double", double)
318336
_ASTYPE_CASE(bool, "float32", float)
@@ -321,11 +339,12 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
321339
_ASTYPE_CASE(bool, "int32", int)
322340
_ASTYPE_CASE(bool, "int64", int64_t)
323341
#undef _ASTYPE_CASE
342+
#undef _ASTYPE_MATCH
324343

325344
throw std::runtime_error(
326345
"astype: unsupported conversion " + std::string(py::str(dt)) +
327346
" -> " + dtype + ". Available targets: float64/double, float32/float, "
328-
"int/int32, int64, bool. Also accepts self-conversion (e.g. float32->float32).");
347+
"int/int32, int64, bool.");
329348
}
330349

331350
/// float64 → float32 → float64 roundtrip

0 commit comments

Comments
 (0)