Skip to content

Commit e9fc00e

Browse files
author
peng.li24
committed
refactor(numpycpp): update init_py.h and manipulation_py.h
1 parent a625e82 commit e9fc00e

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

numpycpp/init_py.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,20 +215,22 @@ inline py::array geomspace(py::object start_o, py::object stop_o,
215215

216216
// ── numpy.eye ────────────────────────────────────────────────────────────────
217217

218-
/// numpy.eye(N, M=N, k=0, dtype=float64)
218+
/// numpy.eye(N, M=None, k=0, dtype=float64) — M=None 表示方阵,严格对齐 numpy API
219219
template<typename T>
220-
py::array_t<T> eye(py::ssize_t N, py::ssize_t M = -1, int k = 0) {
220+
py::array_t<T> eye(py::ssize_t N, py::object M_obj = py::none(), int k = 0) {
221221
if (N < 0) throw std::invalid_argument("eye: N must be >= 0");
222222
size_t Ns = static_cast<size_t>(N);
223-
size_t Ms = (M < 0) ? Ns : static_cast<size_t>(M);
224-
py::array_t<T> result({N, (M < 0 ? N : M)});
223+
py::ssize_t M_val = M_obj.is_none() ? N : M_obj.cast<py::ssize_t>();
224+
if (M_val < 0) throw std::invalid_argument("eye: M must be >= 0");
225+
size_t Ms = static_cast<size_t>(M_val);
226+
py::array_t<T> result({N, M_val});
225227
numpy::eye(static_cast<T*>(result.request().ptr), Ns, Ms, k);
226228
return result;
227229
}
228230

229-
inline py::array_t<double> eye(py::ssize_t N, py::ssize_t M = -1,
231+
inline py::array_t<double> eye(py::ssize_t N, py::object M_obj = py::none(),
230232
int k = 0) {
231-
return eye<double>(N, M, k);
233+
return eye<double>(N, M_obj, k);
232234
}
233235

234236
// ── numpy.identity ───────────────────────────────────────────────────────────

numpycpp/manipulation_py.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,12 +478,12 @@ py::array_t<T> take(const py::array_t<T>& arr,
478478

479479
// ── numpy.compress ────────────────────────────────────────────────────────────
480480

481-
/// numpy.compress(condition, a)
481+
/// numpy.compress(condition, a) — condition FIRST, matching numpy's order
482482
template<typename T>
483-
py::array_t<T> compress(const py::array_t<T>& arr,
484-
const py::array_t<bool>& mask) {
485-
auto buf = arr.request();
486-
auto mbuf = mask.request();
483+
py::array_t<T> compress(const py::array_t<bool>& condition,
484+
const py::array_t<T>& a) {
485+
auto mbuf = condition.request();
486+
auto buf = a.request();
487487
size_t use = static_cast<size_t>(std::min(buf.size, mbuf.size));
488488
const bool* m = static_cast<const bool*>(mbuf.ptr);
489489
size_t cnt = 0;

0 commit comments

Comments
 (0)