-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpybind_example.cpp
More file actions
76 lines (62 loc) · 2.21 KB
/
Copy pathpybind_example.cpp
File metadata and controls
76 lines (62 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// pybind_example.cpp
// Compiled with pybind11
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>
namespace py = pybind11;
// Return the sum of two doubles (bound to Python as `add`).
double add(double a, double b) {
    const double total = a + b;
    return total;
}
/// Return a new array holding every element of `input` multiplied by `factor`.
///
/// @param input  array of any shape and memory layout; elements are read as
///               float64 (pybind11's default array_t conversion applies).
/// @param factor scalar multiplier applied to every element.
/// @return freshly allocated C-contiguous float64 array with the same shape
///         as `input`; `input` itself is not modified.
py::array_t<double> scale_array(py::array_t<double> input, double factor) {
    // The flat loop below is only valid on contiguous row-major memory.
    // Sliced, transposed or reversed views (e.g. a[::2], a.T, a[::-1]) have
    // non-unit or negative strides, so re-view the data as C-contiguous;
    // pybind11/NumPy copy only when the existing layout requires it.
    py::array_t<double, py::array::c_style | py::array::forcecast> contig(input);
    const py::buffer_info in_buf = contig.request();

    // Allocate the output with the input's full shape instead of flattening.
    py::array_t<double> result(in_buf.shape);
    const auto *src = static_cast<const double *>(in_buf.ptr);
    double *dst = result.mutable_data();

    // py::ssize_t is portable (plain ssize_t is POSIX-only, e.g. absent on MSVC).
    for (py::ssize_t i = 0; i < in_buf.size; ++i) {
        dst[i] = src[i] * factor;
    }
    return result;
}
// A simple class
class Integrator {
double a_, b_;
int n_;
public:
Integrator(double a, double b, int n) : a_(a), b_(b), n_(n) {}
double trapezoid(py::object func) {
double h = (b_ - a_) / n_;
double sum = 0.5 * (func(a_).cast<double>() + func(b_).cast<double>());
for (int i = 1; i < n_; i++) {
sum += func(a_ + i * h).cast<double>();
}
return sum * h;
}
std::string describe() const {
return "Integrator(a=" + std::to_string(a_) + ", b=" + std::to_string(b_)
+ ", n=" + std::to_string(n_) + ")";
}
double get_a() const { return a_; }
double get_b() const { return b_; }
int get_n() const { return n_; }
};
// Python module `pybind_example`: exposes add, scale_array and Integrator.
PYBIND11_MODULE(pybind_example, m) {
    m.doc() = "pybind11 example module";

    m.def("add", &add, "Add two numbers", py::arg("a"), py::arg("b"));
    m.def("scale_array", &scale_array, "Scale a numpy array",
          py::arg("input"), py::arg("factor"));

    // The class and its members get docstrings too, so help(Integrator) is
    // as informative as help() on the free functions above.
    py::class_<Integrator>(m, "Integrator",
                           "Composite trapezoidal-rule integrator from a to b "
                           "using n subintervals")
        .def(py::init<double, double, int>(),
             "Create an integrator from a to b with n subintervals",
             py::arg("a"), py::arg("b"), py::arg("n") = 1000)
        .def("trapezoid", &Integrator::trapezoid,
             "Approximate the integral of func(x) from a to b with the "
             "composite trapezoidal rule",
             py::arg("func"))
        .def("describe", &Integrator::describe,
             "Return a string describing the integrator's parameters")
        .def_property_readonly("a", &Integrator::get_a,
                               "Start of the integration interval")
        .def_property_readonly("b", &Integrator::get_b,
                               "End of the integration interval")
        .def_property_readonly("n", &Integrator::get_n,
                               "Number of subintervals")
        .def("__repr__", &Integrator::describe);
}