From c21991b8dbc2ea7a422e6daeb151e2583992610b Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 9 Feb 2026 17:48:17 +0100 Subject: [PATCH] Implementation of warmstart for network simplex can make use off precomputed potentials from sinkhorn or even related simplex --- ot/lp/EMD.h | 2 +- ot/lp/EMD_wrapper.cpp | 16 +- ot/lp/_network_simplex.py | 23 +- ot/lp/emd_wrap.pyx | 23 +- ot/lp/network_simplex_simple.h | 391 ++++++++++++++++++++++++++++- ot/lp/network_simplex_simple_omp.h | 52 +++- 6 files changed, 489 insertions(+), 18 deletions(-) diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 6f408ffeb..905a41851 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -29,7 +29,7 @@ enum ProblemType { MAX_ITER_REACHED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter); +int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init); int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads); int EMD_wrap_sparse( diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 6aa27897a..024cdc66e 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -22,7 +22,8 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - double* alpha, double* beta, double *cost, uint64_t maxIter) { + double* alpha, double* beta, double *cost, uint64_t maxIter, + double* alpha_init, double* beta_init) { // beware M and C are stored in row major C style!!! using namespace lemon; @@ -93,6 +94,19 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, } } + // Set warmstart potentials if provided + if (alpha_init != nullptr && beta_init != nullptr) { + // Compress warmstart potentials to only non-zero entries + std::vector alpha_compressed(n); + std::vector beta_compressed(m); + for (uint64_t i = 0; i < n; i++) { + alpha_compressed[i] = alpha_init[indI[i]]; + } + for (uint64_t j = 0; j < m; j++) { + beta_compressed[j] = beta_init[indJ[j]]; + } + net.setWarmstartPotentials(&alpha_compressed[0], &beta_compressed[0], (int)n, (int)m); + } // Solve the problem with the network simplex algorithm diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index ec06298bc..922d864e8 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -172,6 +172,7 @@ def emd( center_dual=True, numThreads=1, check_marginals=True, + warmstart_dual=None, ): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -237,6 +238,11 @@ def emd( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. + warmstart_dual: tuple of two arrays (alpha, beta), optional (default=None) + Warmstart dual potentials to accelerate convergence. Should be a tuple + (alpha, beta) where alpha is shape (ns,) and beta is shape (nt,). + These potentials are used to guide initial pivots in the network simplex. + Typically obtained from a previous EMD solve or Sinkhorn approximation. .. note:: The solver automatically detects sparse format using the backend's :py:meth:`issparse` method. For sparse inputs: @@ -373,8 +379,18 @@ def emd( a, b, edge_sources, edge_targets, edge_costs, numItermax ) else: + # Prepare warmstart if provided + alpha_init = None + beta_init = None + if warmstart_dual is not None: + alpha_init, beta_init = warmstart_dual + alpha_init = np.asarray(alpha_init, dtype=np.float64) + beta_init = np.asarray(beta_init, dtype=np.float64) + # Dense solver - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + G, cost, u, v, result_code = emd_c( + a, b, M, numItermax, numThreads, alpha_init, beta_init + ) # ============================================================================ # POST-PROCESS DUAL VARIABLES AND CREATE TRANSPORT PLAN @@ -513,6 +529,11 @@ def emd2( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. + warmstart_dual: tuple of two arrays (alpha, beta), optional (default=None) + Warmstart dual potentials to accelerate convergence. Should be a tuple + (alpha, beta) where alpha is shape (ns,) and beta is shape (nt,). + These potentials are used to guide initial pivots in the network simplex. + Typically obtained from a previous EMD solve or Sinkhorn approximation. .. note:: The solver automatically detects sparse format using the backend's :py:meth:`issparse` method. For sparse inputs: diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4f483dfe9..6ca907be1 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -20,7 +20,7 @@ import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil @@ -42,7 +42,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads, alpha_init=None, beta_init=None): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -81,6 +81,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod max_iter : uint64_t The maximum number of iterations before stopping the optimization algorithm if it has not converged. + alpha_init : (ns,) numpy.ndarray, float64, optional + Initial dual potentials for sources (warmstart) + beta_init : (nt,) numpy.ndarray, float64, optional + Initial dual potentials for targets (warmstart) Returns ------- @@ -101,6 +105,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + + # Warmstart potentials + cdef np.ndarray[double, ndim=1, mode="c"] alpha_init_c + cdef np.ndarray[double, ndim=1, mode="c"] beta_init_c + cdef double* alpha_init_ptr = NULL + cdef double* beta_init_ptr = NULL if not len(a): a=np.ones((n1,))/n1 @@ -110,11 +120,18 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod # init OT matrix G=np.zeros([n1, n2]) + + # Setup warmstart pointers if provided + if alpha_init is not None and beta_init is not None: + alpha_init_c = np.ascontiguousarray(alpha_init, dtype=np.float64) + beta_init_c = np.ascontiguousarray(beta_init, dtype=np.float64) + alpha_init_ptr = alpha_init_c.data + beta_init_ptr = beta_init_c.data # calling the function with nogil: if numThreads == 1: - result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, alpha_init_ptr, beta_init_ptr) else: result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads) return G, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index c9fef277e..80c1b8e3c 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -46,6 +46,8 @@ #include #include #include +#include +#include #ifdef HASHMAP #include #else @@ -345,6 +347,11 @@ namespace lemon { int _metric; // 0: sqeuclidean, 1: euclidean, 2: cityblock private: + // Warmstart data + CostVector _warmstart_pi; // Stores warmstart potentials + bool _warmstart_provided; // Flag indicating warmstart is available + bool _warmstart_tree_built; // Flag: tree was built by warmstartInit() + // Data for storing the spanning tree structure IntVector _parent; ArcVector _pred; @@ -768,6 +775,35 @@ namespace lemon { return *this; } + /// \brief Set initial dual potentials for warmstart. + /// + /// This function sets warmstart dual potentials that will be used + /// to guide the initial pivots in the network simplex algorithm. + /// The potentials should come from a previous solution (e.g., Sinkhorn or EMD). + /// + /// \param alpha Source node potentials (size n), where alpha[i] = -pi[source_i] + /// \param beta Target node potentials (size m), where beta[j] = +pi[target_j] + /// \param n Number of source nodes (compressed, non-zero supply) + /// \param m Number of target nodes (compressed, non-zero supply) + /// + /// Note: The sign convention matches EMD potential extraction: + /// alpha = -pi[source], beta = +pi[target] + /// The internal pi convention: reduced cost = cost[arc] + pi[source] - pi[target] + /// So: pi[source_i] = -alpha[i], pi[target_j] = beta[j] + void setWarmstartPotentials(const Cost* alpha, const Cost* beta, int n, int m) { + // Store warmstart potentials with correct sign AND node ID mapping. + // Graph source nodes: 0..n-1, stored at internal index _node_id(i) + // Graph target nodes: n..n+m-1, stored at internal index _node_id(n+j) + // _node_id(k) = _node_num - k - 1 (reversal mapping) + for (int i = 0; i < n; ++i) { + _warmstart_pi[_node_id(i)] = -alpha[i]; // pi[source] = -alpha + } + for (int j = 0; j < m; ++j) { + _warmstart_pi[_node_id(n + j)] = beta[j]; // pi[target] = +beta + } + _warmstart_provided = true; + } + /// @} /// \name Execution Control @@ -809,7 +845,15 @@ namespace lemon { /// \see ProblemType, PivotRule /// \see resetParams(), reset() ProblemType run() { - if (!init()) return INFEASIBLE; + + if (_warmstart_provided) { + if (!warmstartInit()) return INFEASIBLE; + _warmstart_tree_built = true; + } else { + if (!init()) return INFEASIBLE; + _warmstart_tree_built = false; + } + return start(); } @@ -857,6 +901,8 @@ namespace lemon { _cost[i] = 1; } _stype = GEQ; + _warmstart_provided = false; + _warmstart_tree_built = false; // Reset warmstart flag return *this; } @@ -901,6 +947,7 @@ namespace lemon { _supply.resize(all_node_num); _flow.resize(max_arc_num); _pi.resize(all_node_num); + _warmstart_pi.resize(all_node_num); // Initialize warmstart storage _parent.resize(all_node_num); _pred.resize(all_node_num); @@ -1069,6 +1116,341 @@ namespace lemon { private: + // WARMSTART: Build spanning tree from dual potentials + bool warmstartInit() { + if (_node_num == 0) return false; + + // Check supply balance + _sum_supply = 0; + for (int i = 0; i != _node_num; ++i) { + _sum_supply += _supply[i]; + } + if (fabs(_sum_supply) > _EPSILON) return false; + _sum_supply = 0; + + // Compute ART_COST (same as init()) + Cost ART_COST; + if (std::numeric_limits::is_exact) { + ART_COST = std::numeric_limits::max() / 2 + 1; + } else { + ART_COST = 0; + for (ArcsType i = 0; i != _arc_num; ++i) { + if (_cost[i] > ART_COST) ART_COST = _cost[i]; + } + ART_COST = (ART_COST + 1) * _node_num; + } + + // Initialize all real arcs as STATE_LOWER with zero flow + for (ArcsType i = 0; i != _arc_num; ++i) { + _state[i] = STATE_LOWER; + } + + // STEP 1: Build MST using partial sort + union-find + int tree_edges = 0; + std::vector tree_arcs; + tree_arcs.reserve(_node_num); + + { + std::vector arc_order(_arc_num); + std::vector arc_absrc(_arc_num); + for (ArcsType e = 0; e < _arc_num; ++e) { + arc_order[e] = e; + arc_absrc[e] = fabs(_cost[e] + _warmstart_pi[_source[e]] - _warmstart_pi[_target[e]]); + } + + ArcsType K = std::min((ArcsType)(4 * _node_num), _arc_num); + if (K < _arc_num) { + std::nth_element(arc_order.begin(), arc_order.begin() + K, arc_order.end(), + [&](ArcsType a, ArcsType b) { + return arc_absrc[a] < arc_absrc[b]; + }); + } + + std::sort(arc_order.begin(), arc_order.begin() + K, + [&](ArcsType a, ArcsType b) { + return arc_absrc[a] < arc_absrc[b]; + }); + + std::vector uf_parent(_node_num); + std::vector uf_rank(_node_num, 0); + for (int i = 0; i < _node_num; ++i) uf_parent[i] = i; + + for (ArcsType idx = 0; idx < K && tree_edges < _node_num - 1; ++idx) { + ArcsType e = arc_order[idx]; + int s = _source[e]; + int t = _target[e]; + int rs = s, rt = t; + while (uf_parent[rs] != rs) { uf_parent[rs] = uf_parent[uf_parent[rs]]; rs = uf_parent[rs]; } + while (uf_parent[rt] != rt) { uf_parent[rt] = uf_parent[uf_parent[rt]]; rt = uf_parent[rt]; } + if (rs == rt) continue; + if (uf_rank[rs] < uf_rank[rt]) std::swap(rs, rt); + uf_parent[rt] = rs; + if (uf_rank[rs] == uf_rank[rt]) uf_rank[rs]++; + tree_arcs.push_back(e); + tree_edges++; + } + + if (tree_edges < _node_num - 1) { + for (ArcsType idx = K; idx < _arc_num && tree_edges < _node_num - 1; ++idx) { + ArcsType e = arc_order[idx]; + int s = _source[e]; + int t = _target[e]; + int rs = s, rt = t; + while (uf_parent[rs] != rs) { uf_parent[rs] = uf_parent[uf_parent[rs]]; rs = uf_parent[rs]; } + while (uf_parent[rt] != rt) { uf_parent[rt] = uf_parent[uf_parent[rt]]; rt = uf_parent[rt]; } + if (rs == rt) continue; + if (uf_rank[rs] < uf_rank[rt]) std::swap(rs, rt); + uf_parent[rt] = rs; + if (uf_rank[rs] == uf_rank[rt]) uf_rank[rs]++; + tree_arcs.push_back(e); + tree_edges++; + } + } + } + + std::vector tree_adj_deg(_node_num, 0); + for (int k = 0; k < tree_edges; ++k) { + ArcsType e = tree_arcs[k]; + tree_adj_deg[_source[e]]++; + tree_adj_deg[_target[e]]++; + } + std::vector tree_adj_start(_node_num + 1, 0); + for (int i = 0; i < _node_num; ++i) { + tree_adj_start[i + 1] = tree_adj_start[i] + tree_adj_deg[i]; + } + int total_adj = tree_adj_start[_node_num]; + std::vector tree_adj_node(total_adj); + std::vector tree_adj_arc(total_adj); + std::vector tree_adj_pos(_node_num, 0); + for (int k = 0; k < tree_edges; ++k) { + ArcsType e = tree_arcs[k]; + int s = _source[e], t = _target[e]; + int ps = tree_adj_start[s] + tree_adj_pos[s]++; + tree_adj_node[ps] = t; + tree_adj_arc[ps] = e; + int pt = tree_adj_start[t] + tree_adj_pos[t]++; + tree_adj_node[pt] = s; + tree_adj_arc[pt] = e; + } + + // STEP 2: Set up artificial arcs + _search_arc_num = _arc_num; + _all_arc_num = _arc_num + _node_num; + _root = _node_num; + + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _state[e] = STATE_TREE; + if (_supply[u] >= 0) { + _source[e] = u; + _target[e] = _root; + _cost[e] = 0; + _flow[e] = _supply[u]; + } else { + _source[e] = _root; + _target[e] = u; + _cost[e] = ART_COST; + _flow[e] = -_supply[u]; + } + } + + // Root node setup + _parent[_root] = -1; + _pred[_root] = -1; + _supply[_root] = -_sum_supply; + _pi[_root] = 0; + + // STEP 3: BFS from root to build tree structure + std::vector is_rep(_node_num, false); + std::vector visited(_node_num, false); + + for (int u = 0; u < _node_num; ++u) { + if (visited[u]) continue; + is_rep[u] = true; + + _parent[u] = _root; + _pred[u] = _arc_num + u; + _forward[u] = (_supply[u] >= 0); // same as init() + _state[_arc_num + u] = STATE_TREE; + visited[u] = true; + + std::queue bfs_queue; + bfs_queue.push(u); + while (!bfs_queue.empty()) { + int v = bfs_queue.front(); + bfs_queue.pop(); + for (int k = tree_adj_start[v]; k < tree_adj_start[v + 1]; ++k) { + int w = tree_adj_node[k]; + ArcsType arc_e = tree_adj_arc[k]; + if (visited[w]) continue; + visited[w] = true; + + _parent[w] = v; + _pred[w] = arc_e; + _state[arc_e] = STATE_TREE; + _forward[w] = (_source[arc_e] == w); + + _state[_arc_num + w] = STATE_LOWER; + _flow[_arc_num + w] = 0; + + bfs_queue.push(w); + } + } + } + + // STEP 4: Build thread (preorder traversal) + { + std::vector> children(_node_num + 1); + for (int u = 0; u < _node_num; ++u) { + children[_parent[u]].push_back(u); + } + + std::vector preorder; + preorder.reserve(_node_num + 1); + std::stack dfs_stack; + dfs_stack.push(_root); + while (!dfs_stack.empty()) { + int v = dfs_stack.top(); + dfs_stack.pop(); + preorder.push_back(v); + for (int i = (int)children[v].size() - 1; i >= 0; --i) { + dfs_stack.push(children[v][i]); + } + } + + for (int i = 0; i < (int)preorder.size() - 1; ++i) { + _thread[preorder[i]] = preorder[i + 1]; + } + _thread[preorder.back()] = preorder[0]; + + for (int u = 0; u <= _node_num; ++u) { + _rev_thread[_thread[u]] = u; + } + + for (int u = 0; u <= _node_num; ++u) { + _succ_num[u] = 1; + } + for (int i = (int)preorder.size() - 1; i > 0; --i) { + int u = preorder[i]; + _succ_num[_parent[u]] += _succ_num[u]; + } + + std::vector pos(_node_num + 1); + for (int i = 0; i < (int)preorder.size(); ++i) { + pos[preorder[i]] = i; + } + for (int i = 0; i < (int)preorder.size(); ++i) { + int u = preorder[i]; + _last_succ[u] = preorder[pos[u] + _succ_num[u] - 1]; + } + } + + // STEP 5: Compute flows on tree arcs + { + std::vector net(_node_num + 1); + for (int u = 0; u <= _node_num; ++u) { + net[u] = _supply[u]; + } + + std::vector preorder; + preorder.reserve(_node_num + 1); + int cur = _root; + for (int i = 0; i <= _node_num; ++i) { + preorder.push_back(cur); + cur = _thread[cur]; + } + + int ejected = 0; + for (int i = (int)preorder.size() - 1; i > 0; --i) { + int u = preorder[i]; + ArcsType e = _pred[u]; + + Value f = _forward[u] ? net[u] : -net[u]; + + if (f >= 0) { + _flow[e] = f; + net[_parent[u]] += net[u]; + } else { + if (e < _arc_num) { + _state[e] = STATE_LOWER; + _flow[e] = 0; + } + // Reconnect u to root via artificial arc + ArcsType art_e = _arc_num + u; + _parent[u] = _root; + _pred[u] = art_e; + _forward[u] = (_source[art_e] == u); + _state[art_e] = STATE_TREE; + + Value art_f = _forward[u] ? net[u] : -net[u]; + _flow[art_e] = art_f >= 0 ? art_f : -art_f; + if (art_f < 0) { + _forward[u] = !_forward[u]; + _flow[art_e] = -art_f; + } + + net[_root] += net[u]; + ejected++; + } + } + if (ejected > 0) { + std::vector> children2(_node_num + 1); + for (int u = 0; u < _node_num; ++u) { + children2[_parent[u]].push_back(u); + } + // DFS preorder + std::vector preorder2; + preorder2.reserve(_node_num + 1); + std::stack dfs2; + dfs2.push(_root); + while (!dfs2.empty()) { + int v = dfs2.top(); dfs2.pop(); + preorder2.push_back(v); + for (int j = (int)children2[v].size() - 1; j >= 0; --j) { + dfs2.push(children2[v][j]); + } + } + for (int i = 0; i < (int)preorder2.size() - 1; ++i) { + _thread[preorder2[i]] = preorder2[i + 1]; + } + _thread[preorder2.back()] = preorder2[0]; + for (int u = 0; u <= _node_num; ++u) { + _rev_thread[_thread[u]] = u; + } + for (int u = 0; u <= _node_num; ++u) _succ_num[u] = 1; + for (int i = (int)preorder2.size() - 1; i > 0; --i) { + _succ_num[_parent[preorder2[i]]] += _succ_num[preorder2[i]]; + } + std::vector pos2(_node_num + 1); + for (int i = 0; i < (int)preorder2.size(); ++i) pos2[preorder2[i]] = i; + for (int i = 0; i < (int)preorder2.size(); ++i) { + int u = preorder2[i]; + _last_succ[u] = preorder2[pos2[u] + _succ_num[u] - 1]; + } + } + } + + // STEP 6: Compute potentials from the final tree + { + _pi[_root] = 0; + int u = _thread[_root]; + while (u != _root) { + ArcsType e = _pred[u]; + int v = _parent[u]; + if (_forward[u]) { + _pi[u] = _pi[v] - _cost[e]; + } else { + _pi[u] = _pi[v] + _cost[e]; + } + u = _thread[u]; + } + } + + // Initialize in_arc to a valid value + in_arc = 0; + + return true; + } + // Initialize internal data structures bool init() { if (_node_num == 0) return false; @@ -1461,7 +1843,6 @@ namespace lemon { if (_sum_supply >= 0) { if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { // Perform a reverse graph search from the sink to the source - //typename GR::template NodeMap reached(_graph, false); BoolVector reached(_node_num, false); Node s = supply_nodes[0], t = demand_nodes[0]; std::vector stack; @@ -1549,8 +1930,10 @@ namespace lemon { PivotRuleImpl pivot(*this); ProblemType retVal = OPTIMAL; - // Perform heuristic initial pivots - if (!initialPivots()) return UNBOUNDED; + // Perform heuristic initial pivots (skip if warmstart tree was built) + if (!_warmstart_tree_built) { + if (!initialPivots()) return UNBOUNDED; + } uint64_t iter_number = 0; //pivot.setDantzig(true); diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h index d8ef672e3..98131850a 100644 --- a/ot/lp/network_simplex_simple_omp.h +++ b/ot/lp/network_simplex_simple_omp.h @@ -371,6 +371,10 @@ namespace lemon_omp { CostVector _pi; + // Warmstart data + CostVector _warmstart_pi; // Stores warmstart potentials + bool _warmstart_provided; // Flag indicating warmstart is available + // Data for storing the spanning tree structure IntVector _parent; ArcVector _pred; @@ -785,6 +789,30 @@ namespace lemon_omp { return *this; } + /// \brief Set initial dual potentials for warmstart. + /// + /// This function sets warmstart dual potentials that will be used + /// to guide the initial pivots in the network simplex algorithm. + /// The potentials should come from a previous solution (e.g., Sinkhorn or EMD). + /// + /// \param alpha Source node potentials (size n) + /// \param beta Target node potentials (size m) + /// \param n Number of source nodes + /// \param m Number of target nodes + /// + /// Note: The sign convention matches EMD potential extraction: + /// alpha = -pi[source], beta = +pi[target] + void setWarmstartPotentials(const Cost* alpha, const Cost* beta, int n, int m) { + // Store warmstart potentials with correct sign conversion + for (int i = 0; i < n; ++i) { + _warmstart_pi[i] = -alpha[i]; // Negate alpha to convert back to internal pi + } + for (int j = 0; j < m; ++j) { + _warmstart_pi[n + j] = beta[j]; // Beta is already correct sign + } + _warmstart_provided = true; + } + /// @} /// \name Execution Control @@ -830,6 +858,14 @@ namespace lemon_omp { std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; #endif if (!init()) return INFEASIBLE; + + // Apply warmstart potentials after init() if provided + if (_warmstart_provided) { + for (int i = 0; i < _node_num; ++i) { + _pi[i] = _warmstart_pi[i]; + } + } + #if DEBUG_LVL>0 std::cout << "Init done, starting iterations\n"; #endif @@ -877,14 +913,13 @@ namespace lemon_omp { for (int i = 0; i != _node_num; ++i) { _supply[i] = 0; } - for (ArcsType i = 0; i != _arc_num; ++i) { - _cost[i] = 1; - } - _stype = GEQ; - return *this; - } - - + for (ArcsType i = 0; i != _arc_num; ++i) { + _cost[i] = 1; + } + _stype = GEQ; + _warmstart_provided = false; // Reset warmstart flag + return *this; + } /// \brief Reset the internal data structures and all the parameters /// that have been given before. /// @@ -919,6 +954,7 @@ namespace lemon_omp { _supply.resize(all_node_num); _flow.resize(max_arc_num); _pi.resize(all_node_num); + _warmstart_pi.resize(all_node_num); // Initialize warmstart storage _parent.resize(all_node_num); _pred.resize(all_node_num);