Skip to content

Commit 7d2fb30

Browse files
author
peng.li24
committed
test: update module.cpp and add test_astype.py
1 parent 980cbb4 commit 7d2fb30

2 files changed

Lines changed: 74 additions & 8 deletions

File tree

tests/module.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,25 +50,28 @@ PYBIND11_MODULE(numpycpp, m) {
5050
});
5151

5252
// -- astype dtype 匹配诊断 -------------------------------------------------
53-
// 返回 arr 的 dtype 匹配详情,用于验证 py::dtype::of<T>() vs kind()+itemsize 行为
53+
// 返回 arr 的 dtype 匹配详情。仅用 buf.format + itemsize(C API),
54+
// 不调用 dt.kind()(避免 Python↔C++ 递归)。
5455
m.def("_diag_astype_dtype", [](const py::array& arr) -> py::dict {
5556
auto buf = arr.request();
5657
auto dt = arr.dtype();
5758

58-
char kind = dt.kind();
59-
bool is_f32 = (kind == 'f' && buf.itemsize == 4);
60-
bool is_f64 = (kind == 'f' && buf.itemsize == 8);
61-
59+
char fc = buf.format.empty() ? '\0' :
60+
(buf.format[0] == '<' || buf.format[0] == '>' || buf.format[0] == '=')
61+
? buf.format[1] : buf.format[0];
62+
bool is_f32 = (fc == 'f' && buf.itemsize == 4);
63+
bool is_f64 = (fc == 'd' && buf.itemsize == 8) ||
64+
(fc == 'f' && buf.itemsize == 8);
6265
bool dt_of_f32 = dt.is(py::dtype::of<float>());
6366
bool dt_of_f64 = dt.is(py::dtype::of<double>());
6467

6568
py::dict result;
6669
result["dtype_str"] = py::str(dt);
67-
result["kind"] = std::string(1, kind);
70+
result["format_char"] = std::string(1, fc);
6871
result["itemsize"] = buf.itemsize;
6972
result["format"] = buf.format;
70-
result["is_f32(kind)"] = is_f32;
71-
result["is_f64(kind)"] = is_f64;
73+
result["is_f32(fmt)"] = is_f32;
74+
result["is_f64(fmt)"] = is_f64;
7275
result["dtype.is(float)"] = dt_of_f32;
7376
result["dtype.is(double)"]= dt_of_f64;
7477
result["fallback_works"] = (is_f32 || is_f64 || dt_of_f32 || dt_of_f64);

tests/test_astype.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""numpy::astype C++ 单测 — 全部 dtype 转换 + 递归安全性。
2+
运行: pytest tests/test_astype.py -v
3+
"""
4+
import numpy as np, pytest, numpycpp as cpp
5+
6+
# all source × target combinations
7+
_SRC = {"f64": lambda: np.array([1.5,-2.7,3.1], dtype=np.float64),
8+
"f32": lambda: np.array([1.5,-2.7,3.1], dtype=np.float32),
9+
"i32": lambda: np.array([1,-2,3], dtype=np.int32),
10+
"i64": lambda: np.array([1,-2,3], dtype=np.int64),
11+
"bool": lambda: np.array([True,False,True])}
12+
# numpy 约定: ndarray.astype(float) → float64, astype(int) → int32
13+
_EXPECT = {
14+
"float64": np.float64, "double": np.float64,
15+
"float32": np.float32, "float": np.float64,
16+
"int": np.int32, "int32": np.int32,
17+
"int64": np.int64, "bool": np.bool_,
18+
}
19+
20+
@pytest.mark.parametrize("src", _SRC)
21+
@pytest.mark.parametrize("dst,exp_dt", list(_EXPECT.items()))
22+
def test_astype(src, dst, exp_dt):
23+
a = _SRC[src]()
24+
r = np.asarray(cpp.astype(a, dst))
25+
assert r.dtype == exp_dt, f"{src}{dst}: got {r.dtype}, expected {exp_dt}"
26+
assert np.array_equal(r, a.astype(exp_dt)), f"{src}{dst}: value mismatch"
27+
28+
@pytest.mark.parametrize("label,arr,expect_fallback", [
29+
("f32 LE", np.array([1.5],dtype=np.float32), True),
30+
("f32 BE", np.array([1.5],dtype='>f4'), True),
31+
("f64 LE", np.array([1.5],dtype=np.float64), True),
32+
("i32", np.array([1],dtype=np.int32), False), # float fallback 不适用
33+
("bool", np.array([True]), False), # float fallback 不适用
34+
])
35+
def test_dtype_diag(label, arr, expect_fallback):
36+
info = cpp._diag_astype_dtype(arr)
37+
assert info["fallback_works"] == expect_fallback, \
38+
f"{label}: fallback_works={info['fallback_works']} expect={expect_fallback}"
39+
40+
def test_no_recursion_32():
41+
for i in range(32):
42+
a = np.random.RandomState(i).randn(4,2).astype(np.float32)
43+
assert np.asarray(cpp.astype(a,"float64")).dtype == np.float64
44+
45+
def test_no_recursion_1k():
46+
a = np.array([1.5,2.7,3.1], dtype=np.float32)
47+
for _ in range(1000):
48+
r = np.asarray(cpp.astype(a,"float64"))
49+
assert r.dtype == np.float64
50+
51+
def test_empty():
52+
assert np.asarray(cpp.astype(np.array([],dtype=np.float32),"float64")).size == 0
53+
54+
def test_unsupported_raises():
55+
with pytest.raises(RuntimeError): cpp.astype(np.array([1.0]),"complex64")
56+
57+
def test_self_noop():
58+
for dt,val in [(np.float64,1.5),(np.float32,1.5),(np.int32,1),(np.int64,1),(bool,True)]:
59+
assert np.asarray(cpp.astype(np.array([val],dtype=dt),str(np.dtype(dt)))).dtype == dt
60+
61+
if __name__ == "__main__":
62+
import sys,os; sys.path.insert(0,os.path.dirname(os.path.abspath(__file__)))
63+
sys.exit(pytest.main([__file__,"-v","--tb=short"]))

0 commit comments

Comments
 (0)