-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecom_h.cpp
More file actions
332 lines (261 loc) · 12 KB
/
decom_h.cpp
File metadata and controls
332 lines (261 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
//#define _CRTDBG_MAP_ALLOC //用于内存泄露检测
//#include <crtdbg.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <igraph.h>
#include "decom_h.h" // 你的函数声明头文件
namespace py = pybind11;
// Wrapper support for components_forbidden.
// C function declaration (presumably also declared in decom_h.h).
//   graph              - input graph (read-only).
//   components         - output container; the wrapper below treats each
//                        element as a heap-allocated igraph_vector_int_t*
//                        holding the vertex ids of one component.
//   boundaries         - output container; element i is read together with
//                        components[i] and freed the same way.
//   forbidden_vertices - vertex ids to exclude (read-only).
// Returns IGRAPH_SUCCESS on success; any other code is treated as failure.
extern "C" igraph_error_t components_forbidden(
const igraph_t *graph,
igraph_vector_ptr_t *components,
igraph_vector_ptr_t *boundaries,
const igraph_vector_int_t *forbidden_vertices);
// Releases an igraph_vector_ptr_t whose slots hold heap-allocated
// igraph_vector_int_t objects: each element vector is destroyed, its storage
// is released with free(), and finally the pointer container is destroyed.
// NOTE(review): assumes the element vectors were allocated with an allocator
// compatible with free() by the C routines that filled the container.
void free_vector_ptr(igraph_vector_ptr_t *vec_ptr) {
    // Use igraph's own size type: `long` is only 32 bits on 64-bit Windows.
    const igraph_integer_t count = igraph_vector_ptr_size(vec_ptr);
    for (igraph_integer_t i = 0; i < count; ++i) {
        auto *element = static_cast<igraph_vector_int_t*>(VECTOR(*vec_ptr)[i]);
        // igraph_vector_ptr_init() zero-fills, so a producer that fails
        // part-way can leave NULL slots; destroying NULL would crash.
        if (element == nullptr) {
            continue;
        }
        igraph_vector_int_destroy(element);
        free(element);
    }
    igraph_vector_ptr_destroy(vec_ptr);
}
// C++ wrapper exposed to Python as `components_forbidden`.
// Returns a Python list[(component, boundary), ...]: for each entry produced
// by components_forbidden(), the component's vertex ids paired with the
// boundary vertex ids stored at the same index.
//
// Every igraph container is released on all exit paths, including when a
// Python exception (e.g. MemoryError) escapes while building the result.
// Throws std::runtime_error (RuntimeError in Python) on failure.
py::list components_forbidden_wrapper(py::object graph_obj,
                                      const std::vector<int> &forbidden_vertices = {}) {
    // Obtain the underlying igraph_t* from the python-igraph Graph object.
    py::object graph_capsule = graph_obj.attr("__graph_as_capsule")();
    igraph_t *graph = static_cast<igraph_t*>(PyCapsule_GetPointer(graph_capsule.ptr(), nullptr));
    if (!graph) throw std::runtime_error("Invalid igraph capsule");

    // Forbidden vertices: one checked allocation at the final size instead of
    // a series of unchecked push_back calls.
    igraph_vector_int_t forbidden_vec;
    if (igraph_vector_int_init(&forbidden_vec,
                               static_cast<igraph_integer_t>(forbidden_vertices.size())) != IGRAPH_SUCCESS) {
        throw std::runtime_error("Failed to allocate forbidden vertex vector");
    }
    for (std::size_t k = 0; k < forbidden_vertices.size(); ++k) {
        VECTOR(forbidden_vec)[k] = forbidden_vertices[k];
    }

    // Output containers filled by components_forbidden().
    igraph_vector_ptr_t components;
    if (igraph_vector_ptr_init(&components, 0) != IGRAPH_SUCCESS) {
        igraph_vector_int_destroy(&forbidden_vec);
        throw std::runtime_error("Failed to allocate components container");
    }
    igraph_vector_ptr_t boundaries;
    if (igraph_vector_ptr_init(&boundaries, 0) != IGRAPH_SUCCESS) {
        igraph_vector_ptr_destroy(&components);
        igraph_vector_int_destroy(&forbidden_vec);
        throw std::runtime_error("Failed to allocate boundaries container");
    }
    // Frees both output containers together with their element vectors.
    auto release_outputs = [&]() {
        free_vector_ptr(&components);
        free_vector_ptr(&boundaries);
    };

    igraph_error_t err = components_forbidden(graph, &components, &boundaries, &forbidden_vec);
    igraph_vector_int_destroy(&forbidden_vec);
    if (err != IGRAPH_SUCCESS) {
        release_outputs();
        throw std::runtime_error("components_forbidden failed");
    }

    // boundaries[i] is read for every components[i]; refuse to read past the
    // end of `boundaries` if fewer entries were produced.
    const igraph_integer_t n = igraph_vector_ptr_size(&components);
    if (igraph_vector_ptr_size(&boundaries) < n) {
        release_outputs();
        throw std::runtime_error("components_forbidden returned fewer boundaries than components");
    }

    // Copies one igraph integer vector into a new Python list.
    auto to_pylist = [](const igraph_vector_int_t *vec) {
        py::list out;
        const igraph_integer_t size = igraph_vector_int_size(vec);
        for (igraph_integer_t j = 0; j < size; ++j) {
            out.append(VECTOR(*vec)[j]);
        }
        return out;
    };

    // Convert to Python list[(list[int], list[int]), ...].
    py::list result;
    try {
        for (igraph_integer_t i = 0; i < n; ++i) {
            const auto *comp = static_cast<const igraph_vector_int_t*>(VECTOR(components)[i]);
            const auto *bound = static_cast<const igraph_vector_int_t*>(VECTOR(boundaries)[i]);
            result.append(py::make_tuple(to_pylist(comp), to_pylist(bound)));
        }
    } catch (...) {
        // append/make_tuple can raise; do not leak the igraph containers.
        release_outputs();
        throw;
    }
    release_outputs();
    return result;
}
// C++ wrapper exposed to Python as `close_separator`.
// Runs close_separator() from `vertex`, with `forbidden_vertices` as the
// forbidden set, and returns the vertex ids it writes to its output vector
// as a Python list[int].
// The output vector is released on every exit path, including when a Python
// exception escapes during conversion. Throws std::runtime_error on failure.
py::list close_separator_wrapper(py::object graph_obj,
                                 int vertex,
                                 const std::vector<int> &forbidden_vertices = {}) {
    // Obtain the underlying igraph_t* from the python-igraph Graph object.
    py::object graph_capsule = graph_obj.attr("__graph_as_capsule")();
    igraph_t *graph = static_cast<igraph_t*>(PyCapsule_GetPointer(graph_capsule.ptr(), nullptr));
    if (!graph) throw std::runtime_error("Invalid igraph capsule");

    // Forbidden vertices: one checked allocation at the final size.
    igraph_vector_int_t forbidden_vec;
    if (igraph_vector_int_init(&forbidden_vec,
                               static_cast<igraph_integer_t>(forbidden_vertices.size())) != IGRAPH_SUCCESS) {
        throw std::runtime_error("Failed to allocate forbidden vertex vector");
    }
    for (std::size_t k = 0; k < forbidden_vertices.size(); ++k) {
        VECTOR(forbidden_vec)[k] = forbidden_vertices[k];
    }

    // Output vector filled by close_separator().
    igraph_vector_int_t bound_b;
    if (igraph_vector_int_init(&bound_b, 0) != IGRAPH_SUCCESS) {
        igraph_vector_int_destroy(&forbidden_vec);
        throw std::runtime_error("Failed to allocate boundary vector");
    }

    igraph_error_t err = close_separator(graph, vertex,
                                         &forbidden_vec, &bound_b);
    igraph_vector_int_destroy(&forbidden_vec);
    if (err != IGRAPH_SUCCESS) {
        igraph_vector_int_destroy(&bound_b);
        throw std::runtime_error("close_separator_b failed");
    }

    // Convert bound_b to a Python list; release it even if append throws.
    py::list py_bound_b;
    try {
        const igraph_integer_t size = igraph_vector_int_size(&bound_b);
        for (igraph_integer_t i = 0; i < size; ++i) {
            py_bound_b.append(VECTOR(bound_b)[i]);
        }
    } catch (...) {
        igraph_vector_int_destroy(&bound_b);
        throw;
    }
    igraph_vector_int_destroy(&bound_b);
    return py_bound_b;
}
// Declaration of the C routine wrapped by find_convex_hull_wrapper below.
//   graph   - input graph (read-only).
//   r_nodes - input vertex ids (read-only); presumably the set the hull is
//             built around — TODO confirm against the implementation.
//   H_out   - output vector; the wrapper initialises it empty before the call
//             and reads its contents afterwards.
//   method  - algorithm selector as a C string; the wrapper forwards the
//             Python `method` argument verbatim.
// Returns IGRAPH_SUCCESS on success; any other code is treated as failure.
extern "C" igraph_error_t find_convex_hull(
const igraph_t *graph,
const igraph_vector_int_t *r_nodes,
igraph_vector_int_t *H_out,
const char *method // newly added parameter
);
// C++ wrapper exposed to Python as `find_convex_hull`.
// Copies `r_nodes` into an igraph vector, calls find_convex_hull() with
// `method` as a C string, and returns the vertex ids written to the output
// vector as a Python list[int].
// The output vector is released on every exit path, including when a Python
// exception escapes during conversion. Throws std::runtime_error on failure.
py::list find_convex_hull_wrapper(py::object graph_obj,
                                  const std::vector<int> &r_nodes,
                                  const std::string &method) {
    // Obtain the underlying igraph_t* from the python-igraph Graph object.
    py::object graph_capsule = graph_obj.attr("__graph_as_capsule")();
    igraph_t *graph = static_cast<igraph_t*>(PyCapsule_GetPointer(graph_capsule.ptr(), nullptr));
    if (!graph) throw std::runtime_error("Invalid igraph capsule");

    // Input vertices: one checked allocation at the final size.
    igraph_vector_int_t r_vec;
    if (igraph_vector_int_init(&r_vec,
                               static_cast<igraph_integer_t>(r_nodes.size())) != IGRAPH_SUCCESS) {
        throw std::runtime_error("Failed to allocate r_nodes vector");
    }
    for (std::size_t k = 0; k < r_nodes.size(); ++k) {
        VECTOR(r_vec)[k] = r_nodes[k];
    }

    // Output vector filled by find_convex_hull().
    igraph_vector_int_t H_vec;
    if (igraph_vector_int_init(&H_vec, 0) != IGRAPH_SUCCESS) {
        igraph_vector_int_destroy(&r_vec);
        throw std::runtime_error("Failed to allocate hull vector");
    }

    igraph_error_t err = find_convex_hull(graph, &r_vec, &H_vec, method.c_str());
    igraph_vector_int_destroy(&r_vec);
    if (err != IGRAPH_SUCCESS) {
        igraph_vector_int_destroy(&H_vec);
        throw std::runtime_error("find_convex_hull failed");
    }

    // Convert to a Python list; release H_vec even if append throws.
    py::list result;
    try {
        const igraph_integer_t size = igraph_vector_int_size(&H_vec);
        for (igraph_integer_t i = 0; i < size; ++i) {
            result.append(VECTOR(H_vec)[i]);
        }
    } catch (...) {
        igraph_vector_int_destroy(&H_vec);
        throw;
    }
    igraph_vector_int_destroy(&H_vec);
    return result;
}
// Helper: converts an igraph_vector_ptr_t whose elements are
// igraph_vector_int_t* into a nested Python list (list[list[int]]).
// Does not take ownership — the caller still frees the element vectors and
// the container itself.
py::list igraph_vector_ptr_to_pylist(igraph_vector_ptr_t *vec_ptr) {
    py::list result;
    const igraph_integer_t outer_size = igraph_vector_ptr_size(vec_ptr);
    for (igraph_integer_t i = 0; i < outer_size; ++i) {
        const auto *v = static_cast<const igraph_vector_int_t *>(VECTOR(*vec_ptr)[i]);
        py::list inner_list;
        // Index with igraph_integer_t: an `int` index would truncate/overflow
        // for vectors longer than INT_MAX.
        const igraph_integer_t inner_size = igraph_vector_int_size(v);
        for (igraph_integer_t j = 0; j < inner_size; ++j) {
            inner_list.append(VECTOR(*v)[j]);
        }
        result.append(inner_list);
    }
    return result;
}
// Recursive-decomposition wrapper. Takes a Python igraph.Graph directly and
// obtains the underlying pointer via __graph_as_capsule.
// Returns a Python list [atoms, separators], each a list[list[int]] built
// from the containers filled by recursive_decom().
// Both containers (and their element vectors) are released on every exit
// path, including when a Python exception escapes during conversion.
// Throws std::runtime_error on failure.
py::list recursive_decom_wrapper(py::object graph_obj, const std::string &method = "cmsa") {
    // Obtain the underlying igraph_t* through __graph_as_capsule.
    py::object graph_capsule = graph_obj.attr("__graph_as_capsule")();
    igraph_t *g = static_cast<igraph_t *>(PyCapsule_GetPointer(graph_capsule.ptr(), nullptr));
    if (!g) throw std::runtime_error("Invalid igraph capsule");

    // 1. Initialise the two output containers (checked).
    igraph_vector_ptr_t atoms;       // atoms (formerly "blocks")
    igraph_vector_ptr_t separators;  // separators
    if (igraph_vector_ptr_init(&atoms, 0) != IGRAPH_SUCCESS) {
        throw std::runtime_error("Failed to allocate atoms container");
    }
    if (igraph_vector_ptr_init(&separators, 0) != IGRAPH_SUCCESS) {
        igraph_vector_ptr_destroy(&atoms);
        throw std::runtime_error("Failed to allocate separators container");
    }
    // Frees both containers together with their heap-allocated elements.
    auto release_outputs = [&]() {
        free_vector_ptr(&atoms);
        free_vector_ptr(&separators);
    };

    // 2. Run the decomposition. Any C++ exception is mapped to an igraph
    //    error code so the cleanup below still runs.
    igraph_error_t err = IGRAPH_SUCCESS;
    try {
        err = recursive_decom(g, method.c_str(), &atoms, &separators);
    } catch (...) {
        err = IGRAPH_EINTERNAL;
    }
    if (err != IGRAPH_SUCCESS) {
        release_outputs();
        throw std::runtime_error("recursive_decom failed with code " + std::to_string(err));
    }

    // 3. Convert to Python: [atoms_list, separators_list]. Release the igraph
    //    containers whether or not the conversion succeeds.
    py::list result;
    try {
        result.append(igraph_vector_ptr_to_pylist(&atoms));
        result.append(igraph_vector_ptr_to_pylist(&separators));
    } catch (...) {
        release_outputs();
        throw;
    }
    release_outputs();
    return result;
}
// C++ wrapper exposed to Python as `SAHR`.
// Runs SAHR() on `graph` with the vertex ids in `r_nodes` and returns the
// `result_size` integers SAHR writes to its `local2global` output buffer as a
// Python list[int].
// The output buffer is released with free() on every exit path, including
// when a Python exception escapes during conversion.
// Throws std::runtime_error on failure.
py::list SAHR_wrapper(py::object graph_obj, const std::vector<int> &r_nodes) {
    // Obtain the underlying igraph_t* from the python-igraph Graph object.
    py::object graph_capsule = graph_obj.attr("__graph_as_capsule")();
    igraph_t *graph = static_cast<igraph_t*>(PyCapsule_GetPointer(graph_capsule.ptr(), nullptr));
    if (!graph) throw std::runtime_error("Invalid igraph capsule");

    // SAHR receives a non-const int*; hand it a private copy so the caller's
    // vector can never be modified through that pointer.
    std::vector<int> r_vec = r_nodes;
    int *r_ptr = r_vec.data();
    int r_size = static_cast<int>(r_vec.size());

    // Output parameters filled by SAHR().
    int *local2global = nullptr;
    int result_size = 0;
    igraph_error_t err = SAHR(graph, r_ptr, r_size, &local2global, &result_size);
    if (err != IGRAPH_SUCCESS) {
        free(local2global);  // free(nullptr) is a no-op
        throw std::runtime_error("SAHR failed with code " + std::to_string(err));
    }
    // A positive count without a buffer would be dereferenced below.
    if (result_size > 0 && local2global == nullptr) {
        throw std::runtime_error("SAHR returned no output buffer");
    }

    // Convert to a Python list; release the buffer even if append throws.
    py::list result;
    try {
        for (int i = 0; i < result_size; ++i) {
            result.append(local2global[i]);
        }
    } catch (...) {
        free(local2global);
        throw;
    }
    free(local2global);
    return result;
}
// Python module `decom_h`: exposes the igraph-based wrappers defined above.
PYBIND11_MODULE(decom_h, mod) {
    // "name"_a is pybind11 shorthand for py::arg("name").
    using namespace py::literals;

    // Module-level docstring.
    mod.doc() = "Example igraph C extension";

    // SAHR(graph, r_nodes) -> list[int]
    mod.def("SAHR", &SAHR_wrapper,
            "graph"_a,
            "r_nodes"_a,
            "Run SAHR algorithm and return remaining nodes' global indices as a list.");

    // recursive_decom(graph, method="cmsa") -> [atoms, separators]
    mod.def("recursive_decom", &recursive_decom_wrapper,
            "graph"_a,
            "method"_a = "cmsa",
            "Perform recursive decomposition of the graph using the specified method.\n"
            "Returns: A list containing two lists: [atoms, separators].\n"
            " - atoms: List of lists, where each inner list is an atom (a vertex set).\n"
            " - separators: List of lists, where each inner list is a clique minimal separator (a vertex set).");

    // find_convex_hull(graph, r_nodes, method) -> list[int]
    mod.def("find_convex_hull", &find_convex_hull_wrapper,
            "graph"_a,
            "r_nodes"_a,
            "method"_a,
            "Perform CMSA decomposition on the graph starting from nodes r_nodes.");

    // components_forbidden(graph, forbidden_vertices=[]) -> list[(list[int], list[int])]
    mod.def("components_forbidden", &components_forbidden_wrapper,
            "graph"_a,
            "forbidden_vertices"_a = std::vector<int>{},
            "Calculate connected components excluding forbidden vertices, "
            "and return each component with its boundary forbidden nodes.");

    // close_separator(graph, vertex, forbidden_vertices=[]) -> list[int]
    mod.def("close_separator", &close_separator_wrapper,
            "graph"_a,
            "vertex"_a,
            "forbidden_vertices"_a = std::vector<int>{},
            "Calculate forbidden boundary reachable from vertex without traversing forbidden vertices.");
}