1818#include < algorithm>
1919#include < stdexcept>
2020#include < cstdint>
21+ #include < type_traits>
2122
2223namespace py = pybind11;
2324
@@ -267,16 +268,33 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
267268 auto buf = arr.request ();
268269 auto dt = arr.dtype ();
269270
271+ // py::dtype::of<float>() / py::dtype::of<double>() 在 Python 传入的
272+ // numpy 数组上可能不匹配(已知 pybind11 问题)。用 dtype.kind() + itemsize 回退。
273+ char _kind = dt.kind ();
274+ bool _is_f32 = (_kind == ' f' && buf.itemsize == 4 );
275+ bool _is_f64 = (_kind == ' f' && buf.itemsize == 8 );
276+ bool _is_i32 = (_kind == ' i' && buf.itemsize == 4 );
277+ bool _is_i64 = (_kind == ' i' && buf.itemsize == 8 );
278+ bool _is_bool = (_kind == ' b' );
279+
280+ #define _ASTYPE_MATCH (SrcT ) \
281+ (dt.is (py::dtype::of<SrcT>()) || \
282+ (std::is_same<SrcT, float >::value && _is_f32) || \
283+ (std::is_same<SrcT, double >::value && _is_f64) || \
284+ (std::is_same<SrcT, int >::value && _is_i32) || \
285+ (std::is_same<SrcT, int64_t >::value&& _is_i64) || \
286+ (std::is_same<SrcT, bool >::value && _is_bool))
287+
270288#define _ASTYPE_CASE (SrcT, dst_str, DstT ) \
271- if (dt. is (py::dtype::of< SrcT>() ) && (dtype == dst_str)) { \
289+ if (_ASTYPE_MATCH ( SrcT) && (dtype == dst_str)) { \
272290 py::array_t <DstT> r (buf.shape ); \
273291 numpy::astype<DstT, SrcT>(static_cast <const SrcT*>(buf.ptr ), \
274292 static_cast <DstT*>(r.request ().ptr ), buf.size ); \
275293 return r; \
276294 }
277295
278296 // float64
279- _ASTYPE_CASE (double , " float64" , double ) // 自转换
297+ _ASTYPE_CASE (double , " float64" , double )
280298 _ASTYPE_CASE (double , " double" , double )
281299 _ASTYPE_CASE (double , " float32" , float )
282300 _ASTYPE_CASE (double , " float" , float )
@@ -287,14 +305,14 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
287305 // float32
288306 _ASTYPE_CASE (float , " float64" , double )
289307 _ASTYPE_CASE (float , " double" , double )
290- _ASTYPE_CASE (float , " float" , double ) // numpy 约定: np.float32(1).astype(float) → float64
291- _ASTYPE_CASE (float , " float32" , float ) // 自转换: 无操作
308+ _ASTYPE_CASE (float , " float" , double )
309+ _ASTYPE_CASE (float , " float32" , float )
292310 _ASTYPE_CASE (float , " int" , int )
293311 _ASTYPE_CASE (float , " int32" , int )
294312 _ASTYPE_CASE (float , " int64" , int64_t )
295313 _ASTYPE_CASE (float , " bool" , bool )
296314 // int32
297- _ASTYPE_CASE (int , " int32" , int ) // 自转换
315+ _ASTYPE_CASE (int , " int32" , int )
298316 _ASTYPE_CASE (int , " int" , int )
299317 _ASTYPE_CASE (int , " float64" , double )
300318 _ASTYPE_CASE (int , " double" , double )
@@ -303,7 +321,7 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
303321 _ASTYPE_CASE (int , " int64" , int64_t )
304322 _ASTYPE_CASE (int , " bool" , bool )
305323 // int64
306- _ASTYPE_CASE (int64_t , " int64" , int64_t ) // 自转换
324+ _ASTYPE_CASE (int64_t , " int64" , int64_t )
307325 _ASTYPE_CASE (int64_t , " float64" , double )
308326 _ASTYPE_CASE (int64_t , " double" , double )
309327 _ASTYPE_CASE (int64_t , " float32" , float )
@@ -312,7 +330,7 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
312330 _ASTYPE_CASE (int64_t , " int32" , int )
313331 _ASTYPE_CASE (int64_t , " bool" , bool )
314332 // bool
315- _ASTYPE_CASE (bool , " bool" , bool ) // 自转换
333+ _ASTYPE_CASE (bool , " bool" , bool )
316334 _ASTYPE_CASE (bool , " float64" , double )
317335 _ASTYPE_CASE (bool , " double" , double )
318336 _ASTYPE_CASE (bool , " float32" , float )
@@ -321,11 +339,12 @@ inline py::array astype(const py::array& arr, const std::string& dtype) {
321339 _ASTYPE_CASE (bool , " int32" , int )
322340 _ASTYPE_CASE (bool , " int64" , int64_t )
323341#undef _ASTYPE_CASE
342+ #undef _ASTYPE_MATCH
324343
325344 throw std::runtime_error (
326345 " astype: unsupported conversion " + std::string (py::str (dt)) +
327346 " -> " + dtype + " . Available targets: float64/double, float32/float, "
328- " int/int32, int64, bool. Also accepts self-conversion (e.g. float32->float32). " );
347+ " int/int32, int64, bool." );
329348}
330349
331350// / float64 → float32 → float64 roundtrip
0 commit comments