diff --git a/include/meta/caching/maps/locking_map.h b/include/meta/caching/maps/locking_map.h index dfe009278..78e06f099 100644 --- a/include/meta/caching/maps/locking_map.h +++ b/include/meta/caching/maps/locking_map.h @@ -14,6 +14,7 @@ #include #include "meta/util/optional.h" +#include "meta/hashing/hash.h" namespace meta { @@ -76,10 +77,10 @@ class locking_map util::optional find(const Key& key) const; /// iterator type for locking_maps - using iterator = typename std::unordered_map::iterator; + using iterator = typename std::unordered_map>::iterator; /// const_iterator type for locking_maps using const_iterator = - typename std::unordered_map::const_iterator; + typename std::unordered_map>::const_iterator; /** * @return an iterator to the beginning of the map @@ -103,7 +104,7 @@ class locking_map private: /// the underlying map used for storage - std::unordered_map map_; + std::unordered_map> map_; /// the mutex that synchronizes accesses into the map mutable std::mutex mutables_; }; diff --git a/include/meta/caching/shard_cache.h b/include/meta/caching/shard_cache.h index 578b888df..b21392a88 100644 --- a/include/meta/caching/shard_cache.h +++ b/include/meta/caching/shard_cache.h @@ -85,7 +85,8 @@ class generic_shard_cache * The hash function used for determining which shard a key * belongs to. */ - std::hash hasher_; +// std::hash hasher_; + hashing::hash<> hasher_; }; /** diff --git a/include/meta/embeddings/wmd/min_cost_flow.h b/include/meta/embeddings/wmd/min_cost_flow.h new file mode 100644 index 000000000..842b540df --- /dev/null +++ b/include/meta/embeddings/wmd/min_cost_flow.h @@ -0,0 +1,122 @@ +/** + * @file min_cost_flow.h + * @author lolik111 + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef FAST_EMD_MIN_COST_FLOW_H +#define FAST_EMD_MIN_COST_FLOW_H + +#include +#include +#include +#include +#include + +namespace meta +{ +namespace embeddings +{ +template +struct edge; + +template +struct edge_weighted; + +template +class min_cost_flow +{ + + public: + NumT emd_hat(const std::vector& supply, + const std::vector& demand, + const std::vector>& cost); + + // e - supply(positive) and demand(negative). + // c[i] - edges that goes from node i. first is the second nod + // x - the flow is returned in it + NumT compute_min_cost_flow(std::vector& e, + const std::vector>>& c, + std::vector>>& x); + + private: + size_t _num_nodes; + std::vector _nodes_to_demand; + + template + static T integral_emd_hat(const std::vector& supply, + const std::vector& demand, + const std::vector>& cost); + + void compute_shortest_path( + std::vector& d, std::vector& prev, + + size_t from, std::vector>>& cost_forward, + std::vector>>& cost_backward, + + const std::vector& e, size_t& l); + + void heap_decrease_key(std::vector>& demand, + std::vector& nodes_to_demand, size_t v, + NumT alt); + + void heap_remove_first(std::vector>& demand, + std::vector& nodes_to_demand); + + void heapify(std::vector>& demand, + std::vector& nodes_to_demand, size_t i); + + void swap_heap(std::vector>& demand, + std::vector& nodes_to_demand, size_t i, size_t j); + + size_t LEFT(size_t i) + { + return 2 * (i + 1) - 1; + } + + size_t RIGHT(size_t i) + { + return 2 * (i + 1); // 2*(i+1)+1-1 + } + + size_t PARENT(size_t i) + { + return (i - 1) / 2; + } +}; +} +} + +#include "min_cost_flow.tcc" + +#endif // FAST_EMD_MIN_COST_FLOW_H + +// Copyright (c) 2009-2012, Ofir Pele +// All rights reserved. + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of the The Hebrew University of Jerusalem nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/include/meta/embeddings/wmd/min_cost_flow.tcc b/include/meta/embeddings/wmd/min_cost_flow.tcc new file mode 100644 index 000000000..37f9bf512 --- /dev/null +++ b/include/meta/embeddings/wmd/min_cost_flow.tcc @@ -0,0 +1,667 @@ +/** + * @file min_cost_flow.tcc + * @author lolik111 + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#include +#include + +#include "min_cost_flow.h" + +namespace meta +{ +namespace embeddings +{ +template +struct edge +{ + edge(size_t to = 0, T cost = 0) : _to(to), _cost(cost) + { + } + + size_t _to; + T _cost; +}; + +template +struct edge_weighted +{ + edge_weighted(size_t to, T cost, T amount) + : _to(to), _cost(cost), _amount(amount) + { + } + + size_t _to; + T _cost; + T _amount; +}; + +template +NumT min_cost_flow::compute_min_cost_flow( + std::vector& e, const std::vector>>& c, + std::vector>>& x) +{ + + assert(e.size() == c.size()); + assert(x.size() == c.size()); + + _num_nodes = e.size(); + _nodes_to_demand.resize(_num_nodes); + + // reduced costs for forward edges (c[i,j]-pi[i]+pi[j]) + // Note that for forward edges the residual capacity is infinity + std::vector>> r_cost_forward(_num_nodes); + + // reduced costs and capacity for backward edges (c[j,i]-pi[j]+pi[i]) + // Since the flow at the beginning is 0, the residual capacity is also zero + std::vector>> r_cost_cap_backward(_num_nodes); + + for (size_t from = 0; from < _num_nodes; ++from) + { + for (auto it = c[from].begin(); it != c[from].end(); ++it) + { + // init flow + x[from].push_back(edge_weighted(it->_to, it->_cost, 0)); + x[it->_to].push_back(edge_weighted(from, -it->_cost, 0)); + + r_cost_forward[from].push_back(edge(it->_to, it->_cost)); + r_cost_cap_backward[it->_to].push_back( + edge_weighted(from, -it->_cost, 0)); + } + } + + // Max supply + NumT U = 0; + for (size_t i = 0; i < _num_nodes; ++i) + { + if (e[i] > U) + U = e[i]; + } + + std::vector d(_num_nodes); + std::vector prev(_num_nodes); + NumT delta = 1; + while (true) + { // until we break when S or T is empty + + NumT max_supply = 0; + size_t k = 0; + for (size_t i = 0; i < _num_nodes; ++i) + { + if (e[i] > 0) + { + if (max_supply < e[i]) + { + max_supply = e[i]; + k = i; + } + } + } + if (max_supply == 0) + break; + delta = max_supply; + + size_t l; + compute_shortest_path(d, prev, k, r_cost_forward, r_cost_cap_backward, + e, l); + + // find delta (minimum on the path from k to l) + size_t to = l; + do + { + size_t from = prev[to]; + assert(from != to); + + // residual + auto itccb = r_cost_cap_backward[from].begin(); + while ((itccb != r_cost_cap_backward[from].end()) + && (itccb->_to != to)) + { + ++itccb; + } + if (itccb != r_cost_cap_backward[from].end()) + { + if (itccb->_amount < delta) + delta = itccb->_amount; + } + + to = from; + } while (to != k); + + // augment delta flow from k to l (backwards actually...) + to = l; + do + { + size_t from = prev[to]; + assert(from != to); + + auto itx = x[from].begin(); + while (itx->_to != to) + { + ++itx; + } + itx->_amount += delta; + + // update residual for backward edges + auto itccb = r_cost_cap_backward[to].begin(); + while ((itccb != r_cost_cap_backward[to].end()) + && (itccb->_to != from)) + { + ++itccb; + } + if (itccb != r_cost_cap_backward[to].end()) + { + itccb->_amount += delta; + } + itccb = r_cost_cap_backward[from].begin(); + while ((itccb != r_cost_cap_backward[from].end()) + && (itccb->_to != to)) + { + ++itccb; + } + if (itccb != r_cost_cap_backward[from].end()) + { + itccb->_amount -= delta; + } + + // update e + e[to] += delta; + e[from] -= delta; + + to = from; + } while (to != k); + } + + // compute distance from x + NumT dist = 0; + for (size_t from = 0; from < _num_nodes; ++from) + { + for (auto it = x[from].begin(); it != x[from].end(); ++it) + { + dist += (it->_cost * it->_amount); + } + } + + return dist; +} + +template +void min_cost_flow::compute_shortest_path( + std::vector& d, std::vector& prev, size_t from, + std::vector>>& cost_forward, + std::vector>>& cost_backward, + const std::vector& e, size_t& l) +{ + // Making heap (all inf except 0, so we are saving comparisons...) + std::vector> demand(_num_nodes); + + demand[0]._to = from; + _nodes_to_demand[from] = 0; + demand[0]._cost = 0; + + size_t j = 1; + for (size_t i = 0; i < from; ++i) + { + demand[j]._to = i; + _nodes_to_demand[i] = j; + demand[j]._cost = std::numeric_limits::max(); + ++j; + } + + for (size_t i = from + 1; i < _num_nodes; ++i) + { + demand[j]._to = i; + _nodes_to_demand[i] = j; + demand[j]._cost = std::numeric_limits::max(); + ++j; + } + + // main loop + std::vector final_nodes_flg(_num_nodes, false); + do + { + size_t u = demand[0]._to; + + d[u] = demand[0]._cost; // final distance + final_nodes_flg[u] = true; + if (e[u] < 0) + { + l = u; + break; + } + + heap_remove_first(demand, _nodes_to_demand); + + // neighbors of capacity + for (auto it = cost_forward[u].begin(); it != cost_forward[u].end(); + ++it) + { + assert(it->_cost >= 0); + NumT alt = d[u] + it->_cost; + size_t v = it->_to; + if ((_nodes_to_demand[v] < demand.size()) + && (alt < demand[_nodes_to_demand[v]]._cost)) + { + heap_decrease_key(demand, _nodes_to_demand, v, alt); + prev[v] = u; + } + } + + for (auto it = cost_backward[u].begin(); it != cost_backward[u].end(); + ++it) + { + if (it->_amount > 0) + { + assert(it->_cost >= 0); + NumT alt = d[u] + it->_cost; + size_t v = it->_to; + if ((_nodes_to_demand[v] < demand.size()) + && (alt < demand[_nodes_to_demand[v]]._cost)) + { + heap_decrease_key(demand, _nodes_to_demand, v, alt); + prev[v] = u; + } + } + } + } while (!demand.empty()); + + // reduced costs for forward edges (cost[i,j]-pi[i]+pi[j]) + for (size_t node_from = 0; node_from < _num_nodes; ++node_from) + { + + for (auto it = cost_forward[node_from].begin(); + it != cost_forward[node_from].end(); ++it) + { + if (final_nodes_flg[node_from]) + { + it->_cost += d[node_from] - d[l]; + } + if (final_nodes_flg[it->_to]) + { + it->_cost -= d[it->_to] - d[l]; + } + } + } + + // reduced costs and capacity for backward edges (c[j,i]-pi[j]+pi[i]) + for (size_t node_from = 0; node_from < _num_nodes; ++node_from) + { + for (auto it = cost_backward[node_from].begin(); + it != cost_backward[node_from].end(); ++it) + { + if (final_nodes_flg[node_from]) + { + it->_cost += d[node_from] - d[l]; + } + if (final_nodes_flg[it->_to]) + { + it->_cost -= d[it->_to] - d[l]; + } + } + } +} + +template +void min_cost_flow::heap_decrease_key( + std::vector>& demand, std::vector& nodes_to_demand, + size_t v, NumT alt) +{ + size_t i = nodes_to_demand[v]; + demand[i]._cost = alt; + while (i > 0 && demand[PARENT(i)]._cost > demand[i]._cost) + { + swap_heap(demand, nodes_to_demand, i, PARENT(i)); + i = PARENT(i); + } +} + +template +void min_cost_flow::heap_remove_first( + std::vector>& demand, std::vector& nodes_to_demand) +{ + swap_heap(demand, nodes_to_demand, 0, demand.size() - 1); + demand.pop_back(); + heapify(demand, nodes_to_demand, 0); +} + +template +void min_cost_flow::heapify(std::vector>& demand, + std::vector& nodes_to_demand, + size_t i) +{ + do + { + // TODO: change to loop + size_t l = LEFT(i); + size_t r = RIGHT(i); + size_t smallest; + if ((l < demand.size()) && (demand[l]._cost < demand[i]._cost)) + { + smallest = l; + } + else + { + smallest = i; + } + if ((r < demand.size()) && (demand[r]._cost < demand[smallest]._cost)) + { + smallest = r; + } + + if (smallest == i) + return; + + swap_heap(demand, nodes_to_demand, i, smallest); + i = smallest; + + } while (true); +} + +template +void min_cost_flow::swap_heap(std::vector>& demand, + std::vector& nodes_to_demand, + size_t i, size_t j) +{ + edge tmp = demand[i]; + demand[i] = demand[j]; + demand[j] = tmp; + nodes_to_demand[demand[j]._to] = j; + nodes_to_demand[demand[i]._to] = i; +} + +template +NumT min_cost_flow::emd_hat(const std::vector& supply, + const std::vector& demand, + const std::vector>& cost) +{ + if (std::is_integral::value && std::is_signed::value) + { + return integral_emd_hat(supply, demand, cost); + } + else + { + + const double mult_factor = 1000000; + + // Constructing the input + const size_t n = supply.size(); + std::vector i_supply(n); + std::vector i_demand(n); + std::vector> i_cost(n, std::vector(n)); + + // Converting to uint64_t + double sum_supply = 0.0; + double sum_demand = 0.0; + double max_cost = cost[0][0]; + for (size_t i = 0; i < n; ++i) + { + sum_supply += supply[i]; + sum_demand += demand[i]; + for (size_t j = 0; j < n; ++j) + { + if (cost[i][j] > max_cost) + max_cost = cost[i][j]; + } + } + double max_sum = std::max(sum_supply, sum_demand); + double supply_demand_norm_factor = mult_factor / max_sum; + if (max_cost < 1e-12){ + return 0.0; + } + double cost_norm_factor = mult_factor / max_cost; + for (size_t i = 0; i < n; ++i) + { + i_supply[i] = static_cast( + floor(supply[i] * supply_demand_norm_factor + 0.5)); + i_demand[i] = static_cast( + floor(demand[i] * supply_demand_norm_factor + 0.5)); + for (size_t j = 0; j < n; ++j) + { + i_cost[i][j] = static_cast( + floor(cost[i][j] * cost_norm_factor + 0.5)); + } + } + + // computing distance + double dist = integral_emd_hat(i_supply, i_demand, i_cost); + + dist = dist / supply_demand_norm_factor; + dist = dist / cost_norm_factor; + + return dist; + } +} + +template +template +T min_cost_flow::integral_emd_hat( + const std::vector& supply_c, const std::vector& demand_c, + const std::vector>& cost_c) +{ + size_t n = supply_c.size(); + assert(demand_c.size() == n); + + // Ensuring that the supplier - supply, have more mass. + std::vector supply; + std::vector demand; + std::vector> cost(cost_c); + T abs_diff_sum_supply_sum_denamd; + T sum_supply = 0; + T sum_demand = 0; + for (size_t i = 0; i < n; ++i) + { + sum_supply += supply_c[i]; + sum_demand += demand_c[i]; + } + + if (sum_demand > sum_supply) + { + supply = demand_c; + demand = supply_c; + // transpose cost + for (size_t i = 0; i < n; ++i) + { + for (size_t j = 0; j < n; ++j) + { + cost[i][j] = cost_c[j][i]; + } + } + abs_diff_sum_supply_sum_denamd = sum_demand - sum_supply; + } + else + { + supply = supply_c; + demand = demand_c; + abs_diff_sum_supply_sum_denamd = sum_supply - sum_demand; + } + + // creating the b vector that contains all vertexes + std::vector b(2 * n + 2); + const size_t threshold_node = 2 * n; + const size_t artificial_node = 2 * n + 1; // need to be last ! + for (size_t i = 0; i < n; ++i) + { + b[i] = supply[i]; + b[i + n] = demand[i]; + } + + // remark*) Deficit of the extra mass, as mass that flows to the threshold + // node can be absorbed from all sources with cost zero + // This makes sum of b zero. + b[threshold_node] = -abs_diff_sum_supply_sum_denamd; + b[artificial_node] = 0; + + T max_cost = 0; + for (size_t i = 0; i < n; ++i) + { + for (size_t j = 0; j < n; ++j) + { + assert(cost[i][j] >= 0); + if (cost[i][j] > max_cost) + max_cost = cost[i][j]; + } + } + + std::set sources_that_flow_not_only_to_thresh; + std::set sinks_that_get_flow_not_only_from_thresh; + T pre_flow_cost = 0; + + // regular edges between sinks and sources without threshold edges + std::vector>> c(b.size()); + { + for (size_t i = 0; i < n; ++i) + { + if (b[i] == 0) + continue; + { + for (size_t j = 0; j < n; ++j) + { + if (b[j + n] == 0) + continue; + if (cost[i][j] == max_cost) + continue; + c[i].push_back(edge(j + n, cost[i][j])); + + // checking which are not isolated + sources_that_flow_not_only_to_thresh.insert(i); + sinks_that_get_flow_not_only_from_thresh.insert(j + n); + } + } + } + } + + // converting all sinks to negative + for (size_t i = n; i < 2 * n; ++i) + { + b[i] = -b[i]; + } + + // add edges from/to threshold node, + // note that costs are reversed to the paper (see also remark* above) + // It is important that it will be this way because of remark* above. + for (size_t i = 0; i < n; ++i) + { + c[i].push_back(edge(threshold_node, 0)); + c[threshold_node].push_back(edge(i + n, max_cost)); + } + + // artificial arcs - Note the restriction that only one edge i,j is + // artificial so I ignore it... + for (size_t i = 0; i < artificial_node; ++i) + { + c[i].push_back(edge(artificial_node, max_cost + 1)); + c[artificial_node].push_back(edge(i, max_cost + 1)); + } + + // remove nodes with supply demand of 0 + // and vertices that are connected only to the + // threshold vertex + int current_node_name = 0; + // Note here it should be vector and not vector + // as I'm using -1 as a special flag !!! + const int remove_node_flag = -1; + std::vector nodes_new_names(b.size(), remove_node_flag); + std::vector nodes_old_names; + nodes_old_names.reserve(b.size()); + + for (size_t i = 0; i < n * 2; ++i) + { + if (b[i] != 0) + { + if (sources_that_flow_not_only_to_thresh.find(i) + != sources_that_flow_not_only_to_thresh.end() + || sinks_that_get_flow_not_only_from_thresh.find(i) + != sinks_that_get_flow_not_only_from_thresh.end()) + { + nodes_new_names[i] = current_node_name; + nodes_old_names.push_back(i); + ++current_node_name; + } + else + { + if (i >= n) + { // sink + pre_flow_cost -= (b[i] * max_cost); + } + b[threshold_node] += b[i]; // add mass(i=n) + } + } + } + + nodes_new_names[threshold_node] = current_node_name; + nodes_old_names.push_back(threshold_node); + ++current_node_name; + nodes_new_names[artificial_node] = current_node_name; + nodes_old_names.push_back(artificial_node); + ++current_node_name; + + std::vector bb(current_node_name); + size_t j = 0; + for (size_t i = 0; i < b.size(); ++i) + { + if (nodes_new_names[i] != remove_node_flag) + { + bb[j] = b[i]; + ++j; + } + } + + std::vector>> cc(bb.size()); + for (size_t i = 0; i < c.size(); ++i) + { + if (nodes_new_names[i] == remove_node_flag) + continue; + for (auto it = c[i].begin(); it != c[i].end(); ++it) + { + if (nodes_new_names[it->_to] != remove_node_flag) + { + cc[nodes_new_names[i]].push_back( + edge(nodes_new_names[it->_to], it->_cost)); + } + } + } + + min_cost_flow mcf; + T my_dist; + std::vector>> flows(bb.size()); + + T mcf_dist = mcf.compute_min_cost_flow(bb, cc, flows); + + my_dist = pre_flow_cost + // pre-flowing on cases where it was possible + mcf_dist; // solution of the transportation problem + + return my_dist; +} +} +} + +// Copyright (c) 2009-2012, Ofir Pele +// All rights reserved. + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of the The Hebrew University of Jerusalem nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/include/meta/embeddings/wmd/wm_distance.h b/include/meta/embeddings/wmd/wm_distance.h new file mode 100644 index 000000000..3454a0fb9 --- /dev/null +++ b/include/meta/embeddings/wmd/wm_distance.h @@ -0,0 +1,132 @@ +/** + * @file wm_distance.h + * @author lolik111 + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +#ifndef META_EMD_H +#define META_EMD_H + +#include +#include + +#include "meta/caching/all.h" +#include "meta/embeddings/word_embeddings.h" +#include "meta/math/vector.h" + +namespace meta +{ + +namespace embeddings +{ +/** + * Struct representing one document in the wmd processing + */ +struct emb_document +{ + size_t n_terms; + std::vector ids; + std::vector weights; +}; +/** + * Class, providing methods to calculate distance between two documents + * in a sense of word-embedding representation + */ +class wm_distance +{ + public: + using metric_type + = std::function&, + const util::array_view&)>; + + wm_distance( + std:: + shared_ptr, + double>> + cache_, + std::shared_ptr embeddings, + metric_type metric, size_t nthreads = 1); + + /** + * Calculates distance based on type of algorithm + * @param algorithm_type type of the algorithm: "wcd", "rwmd" or "emd" + * @param doc1 + * @param doc2 + * @return distance between two documents + */ + double score(const std::string algorithm_type, const emb_document& doc1, + const emb_document& doc2); + + /** + * Calculates original word mover's distance (based on Matt J. Kusner's + * paper) + * Uses Orif Pele Fast EMD algorithm + * @param doc1 + * @param doc2 + * @return distance between two documents + */ + double emd(const emb_document& doc1, const emb_document& doc2); + /** + * Calculates relaxed EM distance + * @param doc1 + * @param doc2 + * @return distance between two documents + */ + double emd_relaxed(const emb_document& doc1, const emb_document& doc2); + /** + * Calculates World Centroid distance + * @param doc1 + * @param doc2 + * @return distance between two documents + */ + double wcd(const emb_document& doc1, const emb_document& doc2); + + /** + * L2 norm squared of the difference between two word embeddings + * |a - b|2^2 + * @param a + * @param b + * @return distance between two word embeddings + */ + static double l2diff_norm(const util::array_view& a, + const util::array_view& b); + + /** + * Cosine measure between two word embeddings + * Since we want minimum between two similar terms it calculates (1 - cos)/2 + * @param a + * @param b + * @return distance between two word embeddings + */ + static double cosine(const util::array_view& a, + const util::array_view& b); + + private: + const size_t nthreads_; + std::shared_ptr, + double>> + cache_; + std::shared_ptr embeddings_; + const size_t dimension_; + const metric_type dist; + + std::unordered_map> + methods_; + + /** + * Returns distance between two terms using cache + * @param first_word_id first term id + * @param second_word_id second term id + * @return distance between two terms + */ + double f_c_distance(const size_t first_word_id, + const size_t second_word_id); +}; +} +} + +#endif // META_EMD_H diff --git a/include/meta/embeddings/word_embeddings.h b/include/meta/embeddings/word_embeddings.h index fac0ea535..3aab5d57b 100644 --- a/include/meta/embeddings/word_embeddings.h +++ b/include/meta/embeddings/word_embeddings.h @@ -63,12 +63,22 @@ class word_embeddings word_embeddings(std::istream& vocab, std::istream& first, std::istream& second); + + /** + * Loads word embeddings from txt file + * + * @param vectors The stream to read the vectors from + * @param num_lines Number of lines in the file + * @param dimension dimension of the embedding + */ + word_embeddings(std::istream& vectors, size_t num_lines, size_t dimension); + /** * @param term The term to look up * @return the embedding vector (as an array_view) for the given term, * or the vector for the unknown word as appropriate */ - embedding at(util::string_view term) const; + embedding at(std::string term) const; /** * @param tid The term id to look up @@ -76,6 +86,12 @@ class word_embeddings */ util::string_view term(std::size_t tid) const; + /** + * @param tid The term to look up + * @return the term id, or -1 if not found + */ + int64_t tid(std::string term) const; + /** * @param query A vector of the same length as a word embedding to * query for @@ -95,6 +111,13 @@ class word_embeddings */ const util::aligned_vector& vocab() const; + /** + * @param term term_id to look up + * @return the embedding vector (as an array_view) for the given term, + * or the vector for the unknown word as appropriate + */ + util::array_view at(std::size_t tid) const; + private: util::array_view vector(std::size_t tid); @@ -109,7 +132,9 @@ class word_embeddings util::aligned_vector id_to_term_; /// A hash table from a term to its id - hashing::probe_map term_to_id_; + hashing::probe_map term_to_id_; +// hashing::probe_map term_to_id_; + /// The embeddings matrix util::aligned_vector embeddings_; diff --git a/include/meta/index/ranker/wmd_base.h b/include/meta/index/ranker/wmd_base.h new file mode 100644 index 000000000..11a9e6711 --- /dev/null +++ b/include/meta/index/ranker/wmd_base.h @@ -0,0 +1,109 @@ +/** + * @file wmd_base.h + * @author lolik111 + */ + +#ifndef META_WMD_BASE_H +#define META_WMD_BASE_H + +#include "meta/embeddings/wmd/wm_distance.h" +#include "meta/embeddings/word_embeddings.h" +#include "meta/index/ranker/ranker.h" +#include "meta/index/ranker/ranker_factory.h" +#include "meta/util/array_view.h" +#include "meta/util/string_view.h" + +namespace meta +{ +namespace index +{ + +/** + * Implements word mover's distance model. + * + * @see http://mkusner.github.io/publications/WMD.pdf + * + * Required config parameters: + * ~~~toml + * [ranker] + * method = "wmd" + * ~~~ + * + * Optional config parameters: + * ~~~toml + * mode # current mode: can be "emd", "wcd", "rwmd", or + * "prefetch-prune" + * distance-func # type of the distance function: "l2diff" or "cosine" + * num-threads # number of threads used in the algorithm + * cache-per-thread # size of cache per each thread + * ~~~ + */ +class wmd_base : public ranker +{ + public: + /// Identifier for this ranker. + const static util::string_view id; + + const static std::string default_mode; + + const static std::string default_distance_func; + + const static constexpr size_t default_cache_size = 1000000; + + wmd_base(std::shared_ptr fwd, + std::shared_ptr embeddings, + size_t nthreads, size_t cache_size, std::string mode, + std::string distance_func); + + wmd_base(std::istream& in); + + void save(std::ostream& out) const override; + + std::vector + rank(ranker_context& ctx, uint64_t num_results, + const filter_function_type& filter) override; + + private: + std::shared_ptr fwd_; + std::shared_ptr embeddings_; + const size_t nthreads_; + const size_t cache_size_; + std::shared_ptr, + double>> + cache_; + const std::string mode_; + const std::string distance_func_; + /** + * Creates document, omitting terms not presenting in the embeddings + * @param tf vector of term frequences + * @return Struct representing one document in the wmd processing + */ + embeddings::emb_document + create_document(std::vector> tf); + + /** + * Calculates wmd based on the instance of the emd class and mode paralelly + * @param emd + * @param mode + * @param filter + * @param doc_to_compare + * @param docs documents + * @return vector of search results + */ + std::vector process(embeddings::wm_distance emd, + const std::string mode, + const filter_function_type& filter, + embeddings::emb_document doc_to_compare, + std::vector docs); +}; + +/** + * Specialization of the factory method used to create wmd + * rankers. + */ +template <> +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local); +} +} +#endif diff --git a/src/classify/classifier/knn.cpp b/src/classify/classifier/knn.cpp index f7b5ebb94..11e4dab86 100644 --- a/src/classify/classifier/knn.cpp +++ b/src/classify/classifier/knn.cpp @@ -159,7 +159,7 @@ std::unique_ptr make_multi_index_classifier( auto use_weighted = config.get_as("weighted").value_or(false); return make_unique(std::move(training), std::move(inv_idx), *k, - index::make_ranker(*ranker), use_weighted); + index::make_ranker(config, *ranker), use_weighted); } } } diff --git a/src/embeddings/CMakeLists.txt b/src/embeddings/CMakeLists.txt index 25441f3be..ac667ffe9 100644 --- a/src/embeddings/CMakeLists.txt +++ b/src/embeddings/CMakeLists.txt @@ -2,6 +2,7 @@ project(meta-embeddings) add_subdirectory(tools) add_subdirectory(analyzers) +add_subdirectory(wmd) add_library(meta-embeddings cooccurrence_counter.cpp word_embeddings.cpp) target_link_libraries(meta-embeddings cpptoml meta-analyzers meta-util meta-io) diff --git a/src/embeddings/tools/interactive_embeddings.cpp b/src/embeddings/tools/interactive_embeddings.cpp index d5c721287..2b467f0fc 100644 --- a/src/embeddings/tools/interactive_embeddings.cpp +++ b/src/embeddings/tools/interactive_embeddings.cpp @@ -45,7 +45,7 @@ parse_word(util::string_view& query, const embeddings::word_embeddings& glove) if (word.empty()) throw parse_exception{"invalid expression"}; parse_whitespace(query); - return glove.at(word).v; + return glove.at(word.to_string()).v; } std::vector parse_expression(util::string_view& query, diff --git a/src/embeddings/wmd/CMakeLists.txt b/src/embeddings/wmd/CMakeLists.txt new file mode 100644 index 000000000..306b35789 --- /dev/null +++ b/src/embeddings/wmd/CMakeLists.txt @@ -0,0 +1,8 @@ +project(meta-embeddings) + +add_library(meta-wmd wm_distance.cpp) +target_link_libraries(meta-wmd meta-embeddings) + +install(TARGETS meta-wmd + EXPORT meta-exports + DESTINATION lib) diff --git a/src/embeddings/wmd/wm_distance.cpp b/src/embeddings/wmd/wm_distance.cpp new file mode 100644 index 000000000..c41c9684f --- /dev/null +++ b/src/embeddings/wmd/wm_distance.cpp @@ -0,0 +1,215 @@ +/** + * @file wm_distance.cpp + * @author lolik111 + * + * All files in META are dual-licensed under the MIT and NCSA licenses. For more + * details, consult the file LICENSE.mit and LICENSE.ncsa in the root of the + * project. + */ + +//#include +//#include +//#include +//#include + +#include "meta/embeddings/wmd/wm_distance.h" +#include "meta/embeddings/wmd/min_cost_flow.h" +#include "meta/parallel/algorithm.h" + +namespace meta +{ + +namespace embeddings +{ + +wm_distance::wm_distance( + std::shared_ptr, + double>> + cache_, + std::shared_ptr embeddings, metric_type metric, + size_t nthreads /*= 1*/) + : nthreads_(nthreads), + cache_(cache_), + embeddings_(embeddings), + dimension_(embeddings->vector_size()), + dist(metric) +{ + methods_.emplace( + "rwmd", [this](const emb_document& doc1, const emb_document& doc2) { + auto score1 = this->emd_relaxed(doc1, doc2); + auto score2 = this->emd_relaxed(doc2, doc1); + return std::max(score1, score2); + }); + methods_.emplace( + "wcd", [this](const emb_document& doc1, const emb_document& doc2) { + return this->wcd(doc1, doc2); + }); + methods_.emplace( + "emd", [this](const emb_document& doc1, const emb_document& doc2) { + return this->emd(doc1, doc2); + }); +} + +double wm_distance::score(const std::string algorithm_type, + const emb_document& doc1, const emb_document& doc2) +{ + return methods_[algorithm_type](doc1, doc2); +} + +double wm_distance::emd(const emb_document& doc1, const emb_document& doc2) +{ + std::vector supply(doc1.n_terms + doc2.n_terms, 0); + std::vector demand(doc1.n_terms + doc2.n_terms, 0); + + for (size_t i = 0; i < doc1.n_terms; ++i) + { + supply[i] = doc1.weights[i]; + } + + for (size_t i = 0; i < doc2.n_terms; ++i) + { + demand[doc1.n_terms + i] = doc2.weights[i]; + } + + std::vector> cost( + supply.size(), std::vector(supply.size(), 0)); + + for (size_t i = 0; i < doc1.n_terms; ++i) + { + for (size_t j = 0; j < doc2.n_terms; ++j) + { + double dist = f_c_distance(doc1.ids[i], doc2.ids[j]); + assert(dist >= 0); + cost[i][j + doc1.n_terms] = dist; + cost[j + doc1.n_terms][i] = dist; + } + } + embeddings::min_cost_flow mcf; + auto score = mcf.emd_hat(supply, demand, cost); + + return score; +} + +double wm_distance::emd_relaxed(const emb_document& doc1, + const emb_document& doc2) +{ + std::vector ids(doc2.n_terms); + for (size_t i = 0; i < doc2.n_terms; i++) + { + ids[i] = i; + } + + double acc = 0; + for (size_t i = 0; i < doc1.n_terms; i++) + { + std::vector distance(doc2.n_terms); + for (size_t j = 0; j < doc2.n_terms; ++j) + { + distance[j] = f_c_distance(doc1.ids[i], doc2.ids[j]); + } + + if (doc1.weights[i] != 0) + { + std::sort(ids.begin(), ids.end(), + [&](const size_t a, const size_t b) { + bool ans; + ans = distance[a] < distance[b]; + return ans; + }); + + double remaining = doc1.weights[i]; + for (size_t j = 0; j < doc2.n_terms; j++) + { + uint64_t w = ids[j]; + if (remaining < doc2.weights[w]) + { + acc += remaining * distance[w]; + break; + } + else + { + remaining -= doc2.weights[w]; + acc += doc2.weights[w] * distance[w]; + } + } + } + } + return acc; +} + +double wm_distance::wcd(const emb_document& doc1, const emb_document& doc2) +{ + using namespace meta::math::operators; + + std::vector res1(dimension_, 0); + std::vector res2(dimension_, 0); + + auto start = doc1.ids.begin(); + for (auto w1 : doc1.weights) + { + res1 = res1 + embeddings_->at(*start++) * w1; + } + + start = doc2.ids.begin(); + for (auto w2 : doc2.weights) + { + res2 = res2 + embeddings_->at(*start++) * w2; + } + + return dist(res1, res2); +} + +double wm_distance::l2diff_norm(const util::array_view& a, + const util::array_view& b) +{ + double res = 0.0; + auto it1 = a.begin(); + auto it2 = b.begin(); + if (it1 == it2) + { + return 0; + } + + while (it1 != a.end()) + { + double val = *it1 - *it2; + res += val * val; + it1++; + it2++; + } + + return res; +} + +double wm_distance::cosine(const util::array_view& a, + const util::array_view& b) +{ + if (a.begin() == b.begin()) + return 0; + return (1.0 - std::inner_product(a.begin(), a.end(), b.begin(), 0.0)) / 2.0; +} + +double wm_distance::f_c_distance(const size_t first_word_id, + const size_t second_word_id) +{ + std::pair pair; + if (first_word_id < second_word_id) + { + pair = {first_word_id, second_word_id}; + } + else + { + pair = {second_word_id, first_word_id}; + } + + auto val = cache_->find(pair); + + return val.value_or([&]() { + auto dst = dist(embeddings_->at(first_word_id), + embeddings_->at(second_word_id)); + cache_->insert(pair, dst); + return dst; + }()); +} +} +} diff --git a/src/embeddings/word_embeddings.cpp b/src/embeddings/word_embeddings.cpp index e2013f98a..56436aa0a 100644 --- a/src/embeddings/word_embeddings.cpp +++ b/src/embeddings/word_embeddings.cpp @@ -21,6 +21,39 @@ namespace embeddings using vocab_type = hashing::probe_map; +word_embeddings::word_embeddings(std::istream& vectors, size_t num_lines, + size_t dimension) + : vector_size_{dimension}, + id_to_term_(num_lines), + term_to_id_{static_cast(std::ceil( + id_to_term_.size() / vocab_type::default_max_load_factor()))}, + embeddings_(vector_size_ * (id_to_term_.size() + 1)) +{ + printing::progress progress{" > Loading embeddings: ", id_to_term_.size()}; + + for (std::size_t tid = 0; tid < id_to_term_.size(); ++tid) + { + if (!vectors) + throw word_embeddings_exception{ + "embeddings stream ended unexpectedly"}; + + progress(tid); + + vectors >> id_to_term_[tid]; + term_to_id_[id_to_term_[tid]] = tid; + + auto vec = vector(tid); + std::generate(vec.begin(), vec.end(), [&]() { + double v; + vectors >> v; + return v; + }); + auto len = math::operators::l2norm(vec); + std::transform(vec.begin(), vec.end(), vec.begin(), + [=](double weight) { return weight / len; }); + } +} + word_embeddings::word_embeddings(std::istream& vocab, std::istream& vectors) : vector_size_{io::packed::read(vectors)}, id_to_term_(io::packed::read(vocab)), @@ -43,6 +76,7 @@ word_embeddings::word_embeddings(std::istream& vocab, std::istream& vectors) std::generate(vec.begin(), vec.end(), [&]() { return io::packed::read(vectors); }); } + } word_embeddings::word_embeddings(std::istream& vocab, std::istream& first, @@ -109,7 +143,7 @@ util::array_view word_embeddings::vector(std::size_t tid) const return {embeddings_.data() + tid * vector_size_, vector_size_}; } -embedding word_embeddings::at(util::string_view term) const +embedding word_embeddings::at(std::string term) const { std::size_t tid; auto v_it = term_to_id_.find(term); @@ -124,6 +158,27 @@ embedding word_embeddings::at(util::string_view term) const return {tid, vector(tid)}; } +util::array_view word_embeddings::at(std::size_t tid) const +{ + return vector(tid); +} + + +int64_t word_embeddings::tid(std::string term) const +{ + int64_t tid; + auto v_it = term_to_id_.find(term); + if (v_it == term_to_id_.end()) + { + tid = -1; + } + else + { + tid = v_it->value(); + } + return tid; +} + util::string_view word_embeddings::term(std::size_t tid) const { if (tid >= id_to_term_.size()) @@ -164,6 +219,7 @@ const util::aligned_vector& word_embeddings::vocab() const return id_to_term_; } + word_embeddings load_embeddings(const cpptoml::table& config) { auto prefix = config.get_as("prefix"); @@ -175,6 +231,30 @@ word_embeddings load_embeddings(const cpptoml::table& config) throw word_embeddings_exception{"embeddings directory does not exist: " + *prefix}; + auto mode = config.get_as("mode").value_or("average"); + + if (mode == "txt") + { + std::ifstream target{*prefix + "/embeddings.target.txt"}; + if (!target) + throw word_embeddings_exception{"missing target vectors in: " + + *prefix}; + auto lines = filesystem::num_lines(*prefix + "/embeddings.target.txt"); + auto dim = config.get_as("vector-size"); + if (!dim) + { + std::string line; + std::getline(target, line); + std::istringstream iss(line); + std::vector results( + (std::istream_iterator(iss)), + std::istream_iterator()); + dim = results.size() - 1; + target.seekg(0, target.beg); + } + return {target, lines, *dim}; + } + std::ifstream vocab{*prefix + "/vocab.bin", std::ios::binary}; if (!vocab) throw word_embeddings_exception{"missing vocabulary file in: " @@ -184,7 +264,6 @@ word_embeddings load_embeddings(const cpptoml::table& config) std::ifstream context{*prefix + "/embeddings.context.bin", std::ios::binary}; - auto mode = config.get_as("mode").value_or("average"); if (mode == "average") { if (!target) diff --git a/src/index/ranker/CMakeLists.txt b/src/index/ranker/CMakeLists.txt index 20518f751..43a3a5245 100644 --- a/src/index/ranker/CMakeLists.txt +++ b/src/index/ranker/CMakeLists.txt @@ -9,8 +9,9 @@ add_library(meta-ranker absolute_discount.cpp kl_divergence_prf.cpp rocchio.cpp ranker.cpp - ranker_factory.cpp) -target_link_libraries(meta-ranker meta-index) + ranker_factory.cpp + wmd_base.cpp) +target_link_libraries(meta-ranker meta-index meta-wmd) install(TARGETS meta-ranker EXPORT meta-exports diff --git a/src/index/ranker/ranker_factory.cpp b/src/index/ranker/ranker_factory.cpp index 86c1069af..9b0cb2d04 100644 --- a/src/index/ranker/ranker_factory.cpp +++ b/src/index/ranker/ranker_factory.cpp @@ -3,6 +3,7 @@ * @author Chase Geigle */ +#include #include "cpptoml.h" #include "meta/index/ranker/all.h" #include "meta/index/ranker/ranker_factory.h" @@ -31,6 +32,7 @@ ranker_factory::ranker_factory() reg(); reg(); reg(); + reg(); } std::unique_ptr make_ranker(const cpptoml::table& config) diff --git a/src/index/ranker/wmd_base.cpp b/src/index/ranker/wmd_base.cpp new file mode 100644 index 000000000..ce06b398f --- /dev/null +++ b/src/index/ranker/wmd_base.cpp @@ -0,0 +1,295 @@ +/** + * @file wmd_base.cpp + * @author lolik111 + */ + +#include "meta/index/ranker/wmd_base.h" +#include "meta/index/forward_index.h" +#include "meta/index/postings_data.h" +#include "meta/parallel/parallel_for.h" +#include "meta/util/fixed_heap.h" + +namespace meta +{ +namespace index +{ +const util::string_view wmd_base::id = "wmd-base"; + +const std::string wmd_base::default_mode = "rwmd"; + +const std::string wmd_base::default_distance_func = "cosine"; + +const constexpr size_t wmd_base::default_cache_size; + +wmd_base::wmd_base(std::shared_ptr fwd, + std::shared_ptr embeddings, + size_t nthreads, size_t cache_size, std::string mode, + std::string distance_func) + : fwd_(fwd), + embeddings_(embeddings), + nthreads_(nthreads), + cache_size_(cache_size), + cache_{std::make_shared, + double>>(nthreads, + cache_size)}, + mode_(mode), + distance_func_(distance_func) +{ +} + +void wmd_base::save(std::ostream& out) const +{ + io::packed::write(out, id); + io::packed::write(out, nthreads_); + io::packed::write(out, cache_size_); + io::packed::write(out, mode_); + io::packed::write(out, distance_func_); + io::packed::write(out, fwd_->index_name()); +} + +wmd_base::wmd_base(std::istream& in) + : nthreads_{io::packed::read(in)}, + cache_size_{io::packed::read(in)}, + cache_{std::make_shared, + double>>(nthreads_, + cache_size_)}, + mode_{io::packed::read(in)}, + distance_func_{io::packed::read(in)} +{ + auto path = io::packed::read(in); + auto cfg = cpptoml::parse_file(path + "/config.toml"); + fwd_ = make_index(*cfg); + + embeddings_ = std::make_shared( + embeddings::load_embeddings(*cfg)); +} + +std::vector wmd_base::rank(ranker_context& ctx, + uint64_t num_results, + const filter_function_type& filter) +{ + auto results = util::make_fixed_heap( + num_results, [](const search_result& a, const search_result& b) { + return a.score < b.score; + }); + + embeddings::wm_distance::metric_type distance; + if (distance_func_ == "cosine") + { + distance = embeddings::wm_distance::cosine; + } + else if (distance_func_ == "l2diff") + { + distance = embeddings::wm_distance::l2diff_norm; + } + else + { + distance = embeddings::wm_distance::cosine; + } + std::vector> tf_pc; + tf_pc.reserve(ctx.postings.size()); + for (auto one : ctx.postings) + { + tf_pc.push_back({one.t_id, one.query_term_weight}); + } + + auto doc_to_compare = create_document(tf_pc); + if (doc_to_compare.n_terms == 0) + { + return results.extract_top(); // empty + } + + parallel::thread_pool pool(nthreads_); + std::vector docs = fwd_->docs(); + + if (mode_ != "prefetch-prune") + { + embeddings::wm_distance emd(cache_, embeddings_, distance); + auto scores = process(emd, mode_, filter, doc_to_compare, fwd_->docs()); + for (auto score : scores) + { + results.emplace(score); + } + } + else + { + embeddings::wm_distance emd(cache_, embeddings_, distance); + + // wcd phase + auto scores = process( + {cache_, embeddings_, embeddings::wm_distance::l2diff_norm}, "wcd", + filter, doc_to_compare, fwd_->docs()); + + std::sort(scores.begin(), scores.end(), + [&](const search_result a, const search_result b) { + bool ans; + ans = a.score < b.score; + return ans; + }); + + std::vector k_docs; + for (size_t i = 0; i < num_results; i++) + { + k_docs.push_back(scores[i].d_id); + } + scores.erase(scores.begin(), scores.begin() + num_results); + // emd after wcd + auto k_emd = process(emd, "emd", filter, doc_to_compare, k_docs); + for (auto sr : k_emd) + { + results.emplace(sr); + } + + // worst result + auto last = (--results.end())->score; + + // how much documents compare using with rwmd + const size_t magic_constant + = std::max(static_cast(fwd_->docs().size() / 8), + static_cast(num_results * 8)); + + std::vector rwmd_docs(magic_constant); + auto start = scores.begin(); + std::generate(rwmd_docs.begin(), rwmd_docs.end(), + [&]() { return (*start++).d_id; }); + // rwmd phase + auto rwmd_results + = process(emd, "rwmd", filter, doc_to_compare, rwmd_docs); + + std::vector pretend_docs; + + for (auto sr : rwmd_results) + { + if (sr.score < last) + { + pretend_docs.emplace_back(sr.d_id); + } + } + + if (!pretend_docs.empty()) + { // emd phase + auto pretend_results + = process(emd, "emd", filter, doc_to_compare, pretend_docs); + for (auto sr : pretend_results) + { + results.emplace(sr); + } + } + } + + return results.extract_top(); +} + +std::vector +wmd_base::process(embeddings::wm_distance emd, const std::string mode, + const filter_function_type& filter, + embeddings::emb_document doc_to_compare, + std::vector docs) +{ + parallel::thread_pool pool(nthreads_); + + auto scores = parallel::for_each_block( + docs.begin(), docs.end(), pool, [&](std::vector::iterator start, + std::vector::iterator end) { + std::vector block_scores; + for (auto it = start; it != end; ++it) + { + if (!filter(*it)) + continue; + + auto doc = create_document(fwd_->search_primary(*it)->counts()); + + if (doc.n_terms == 0) + { + continue; + } + auto score + = static_cast(emd.score(mode, doc, doc_to_compare)); + block_scores.emplace_back(*it, score); + } + return block_scores; + }); + + std::vector results; + results.reserve(fwd_->docs().size()); + for (auto& vec : scores) + { + for (auto sr : vec.get()) + { + results.emplace_back(sr); + } + } + return results; +} + +embeddings::emb_document +wmd_base::create_document(std::vector> tf) +{ + size_t unique_terms_count = tf.size(); + size_t all_terms_count = 0; + + embeddings::emb_document document; + document.ids = std::vector(); + document.ids.reserve(unique_terms_count); + document.weights = std::vector(); + document.weights.reserve(unique_terms_count); + + for (auto term_data : tf) + { + std::string term = fwd_->term_text(term_data.first); + auto vec_id = this->embeddings_->tid(term); + + if (vec_id >= 0) + { + all_terms_count += term_data.second; + document.weights.emplace_back(term_data.second); + document.ids.emplace_back(vec_id); + } + else + { + unique_terms_count--; + } + } + + using namespace meta::math::operators; + + document.weights = document.weights / all_terms_count; + document.n_terms = unique_terms_count; + + return document; +} + +template <> +std::unique_ptr make_ranker(const cpptoml::table& global, + const cpptoml::table& local) +{ + if (global.begin() == global.end()) + throw ranker_exception{"empty global configuration provided to " + "construction of wmd_base ranker"}; + auto f_idx = make_index(global); + + auto embeddings = global.get_table("embeddings"); + if (!embeddings) + throw std::runtime_error{"\"embeddings\" group needed in config file!"}; + + auto glove = embeddings::load_embeddings(*embeddings); + + auto cache_size = local.get_as("cache-per-thread") + .value_or(wmd_base::default_cache_size); + size_t nthreads = local.get_as("num-threads") + .value_or(std::thread::hardware_concurrency()); + + auto mode + = local.get_as("mode").value_or(wmd_base::default_mode); + + auto distance_func = local.get_as("distance-func") + .value_or(wmd_base::default_distance_func); + + return make_unique( + f_idx, std::make_shared(glove), nthreads, + cache_size, mode, distance_func); +} +} +}