|
| 1 | +"""numpy::astype C++ 单测 — 全部 dtype 转换 + 递归安全性。 |
| 2 | +运行: pytest tests/test_astype.py -v |
| 3 | +""" |
| 4 | +import numpy as np, pytest, numpycpp as cpp |
| 5 | + |
| 6 | +# all source × target combinations |
| 7 | +_SRC = {"f64": lambda: np.array([1.5,-2.7,3.1], dtype=np.float64), |
| 8 | + "f32": lambda: np.array([1.5,-2.7,3.1], dtype=np.float32), |
| 9 | + "i32": lambda: np.array([1,-2,3], dtype=np.int32), |
| 10 | + "i64": lambda: np.array([1,-2,3], dtype=np.int64), |
| 11 | + "bool": lambda: np.array([True,False,True])} |
| 12 | +# numpy 约定: ndarray.astype(float) → float64, astype(int) → int32 |
| 13 | +_EXPECT = { |
| 14 | + "float64": np.float64, "double": np.float64, |
| 15 | + "float32": np.float32, "float": np.float64, |
| 16 | + "int": np.int32, "int32": np.int32, |
| 17 | + "int64": np.int64, "bool": np.bool_, |
| 18 | +} |
| 19 | + |
| 20 | +@pytest.mark.parametrize("src", _SRC) |
| 21 | +@pytest.mark.parametrize("dst,exp_dt", list(_EXPECT.items())) |
| 22 | +def test_astype(src, dst, exp_dt): |
| 23 | + a = _SRC[src]() |
| 24 | + r = np.asarray(cpp.astype(a, dst)) |
| 25 | + assert r.dtype == exp_dt, f"{src}→{dst}: got {r.dtype}, expected {exp_dt}" |
| 26 | + assert np.array_equal(r, a.astype(exp_dt)), f"{src}→{dst}: value mismatch" |
| 27 | + |
| 28 | +@pytest.mark.parametrize("label,arr,expect_fallback", [ |
| 29 | + ("f32 LE", np.array([1.5],dtype=np.float32), True), |
| 30 | + ("f32 BE", np.array([1.5],dtype='>f4'), True), |
| 31 | + ("f64 LE", np.array([1.5],dtype=np.float64), True), |
| 32 | + ("i32", np.array([1],dtype=np.int32), False), # float fallback 不适用 |
| 33 | + ("bool", np.array([True]), False), # float fallback 不适用 |
| 34 | +]) |
| 35 | +def test_dtype_diag(label, arr, expect_fallback): |
| 36 | + info = cpp._diag_astype_dtype(arr) |
| 37 | + assert info["fallback_works"] == expect_fallback, \ |
| 38 | + f"{label}: fallback_works={info['fallback_works']} expect={expect_fallback}" |
| 39 | + |
| 40 | +def test_no_recursion_32(): |
| 41 | + for i in range(32): |
| 42 | + a = np.random.RandomState(i).randn(4,2).astype(np.float32) |
| 43 | + assert np.asarray(cpp.astype(a,"float64")).dtype == np.float64 |
| 44 | + |
| 45 | +def test_no_recursion_1k(): |
| 46 | + a = np.array([1.5,2.7,3.1], dtype=np.float32) |
| 47 | + for _ in range(1000): |
| 48 | + r = np.asarray(cpp.astype(a,"float64")) |
| 49 | + assert r.dtype == np.float64 |
| 50 | + |
| 51 | +def test_empty(): |
| 52 | + assert np.asarray(cpp.astype(np.array([],dtype=np.float32),"float64")).size == 0 |
| 53 | + |
| 54 | +def test_unsupported_raises(): |
| 55 | + with pytest.raises(RuntimeError): cpp.astype(np.array([1.0]),"complex64") |
| 56 | + |
| 57 | +def test_self_noop(): |
| 58 | + for dt,val in [(np.float64,1.5),(np.float32,1.5),(np.int32,1),(np.int64,1),(bool,True)]: |
| 59 | + assert np.asarray(cpp.astype(np.array([val],dtype=dt),str(np.dtype(dt)))).dtype == dt |
| 60 | + |
| 61 | +if __name__ == "__main__": |
| 62 | + import sys,os; sys.path.insert(0,os.path.dirname(os.path.abspath(__file__))) |
| 63 | + sys.exit(pytest.main([__file__,"-v","--tb=short"])) |
0 commit comments