Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 68 additions & 38 deletions include/ParallelPriotityQueue/SpapQueueWorker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ class WorkerResource {

private:
const std::array<std::size_t, tables::maxTableSize<GlobalQType::netw_>()>
channelIndices_; ///< Order of outgoing
///< channels to push to.
std::array<value_type, GlobalQType::netw_.maxBatchSize()> outBuffer_; ///< Small buffer before
///< pushing to outgoing
///< channel.
channelIndices_; ///< Order of outgoing
///< channels to push to.
std::array<value_type, 2U * GlobalQType::netw_.maxBatchSize()> outBuffer_; ///< Small buffer before
///< pushing to outgoing
///< channel.

const std::size_t workerId_; ///< Worker Id in the global queue.
std::size_t localCount_{0U}; ///< A partial account of the number of tasks in the global queue.
GlobalQType &globalQueue_; ///< Reference to the global queue.
typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::iterator
bufferPointer_; ///< Pointer to the next free spot in the outBuffer_.
std::size_t bufferHead_{0U}; ///< Head of out ring buffer.
std::size_t bufferTail_{0U}; ///< Tail of out ring buffer.
typename std::array<std::size_t, tables::maxTableSize<GlobalQType::netw_>()>::const_iterator
channelPointer_; ///< Pointer to the next outgoing channel.
const typename std::array<std::size_t, tables::maxTableSize<GlobalQType::netw_>()>::const_iterator
Expand All @@ -75,8 +75,7 @@ class WorkerResource {
inline void decrGlobalCount() noexcept;

[[nodiscard("Push may fail when channel is full.\n")]] inline bool pushOutBuffer() noexcept;
inline void pushOutBufferSelf(
const typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::iterator fromPointer) noexcept;
inline void pushOutBufferSelf(const std::size_t numElements) noexcept;

inline void enqueueInChannels() noexcept;
virtual void processElement(const value_type val) noexcept = 0;
Expand Down Expand Up @@ -121,7 +120,10 @@ class WorkerResource {
* @see SpapQueue
* @see WorkerResource
*/
template <template <class, BasicQueue, std::size_t> class WorkerTemplate, class GlobalQType, BasicQueue LocalQType, std::size_t N>
template <template <class, BasicQueue, std::size_t> class WorkerTemplate,
class GlobalQType,
BasicQueue LocalQType,
std::size_t N>
consteval bool isDerivedWorkerResource() {
static_assert(N <= GlobalQType::netw_.numWorkers_);

Expand All @@ -143,7 +145,10 @@ consteval bool isDerivedWorkerResource() {
* @tparam LocalQType Worker local queue type.
* @tparam N Tuple of first N workers.
*/
template <template <class, BasicQueue, std::size_t> class WorkerTemplate, class GlobalQType, class LocalQType, std::size_t N>
template <template <class, BasicQueue, std::size_t> class WorkerTemplate,
class GlobalQType,
class LocalQType,
std::size_t N>
struct WorkerCollectiveHelper {
static_assert(N <= GlobalQType::netw_.numWorkers_);
template <typename... Args>
Expand All @@ -170,7 +175,6 @@ constexpr WorkerResource<GlobalQType, LocalQType, numPorts>::WorkerResource(
tables::extendTable<tables::maxTableSize<GlobalQType::netw_>(), channelIndicesLength>(channelIndices)),
workerId_(workerId),
globalQueue_(globalQueue),
bufferPointer_(outBuffer_.begin()),
channelPointer_(channelIndices_.cbegin()),
channelTableEndPointer_(std::next(channelIndices_.cbegin(), channelIndicesLength)),
queue_(std::forward<Args>(localQargs)...) { }
Expand All @@ -196,22 +200,21 @@ inline bool WorkerResource<GlobalQType, LocalQType, numPorts>::push(InputIt firs
*/
template <typename GlobalQType, BasicQueue LocalQType, std::size_t numPorts>
inline void WorkerResource<GlobalQType, LocalQType, numPorts>::enqueueGlobal(const value_type val) noexcept {
assert(bufferPointer_ != outBuffer_.end());
assert(bufferTail_ <= bufferHead_);
assert(bufferHead_ < bufferTail_ + outBuffer_.size());

incrGlobalCount();
*bufferPointer_ = val;
++bufferPointer_;
outBuffer_[bufferHead_ % outBuffer_.size()] = val;
++bufferHead_;

std::size_t maxAttempts = GlobalQType::netw_.maxPushAttempts_;
while (static_cast<std::size_t>(std::distance(outBuffer_.begin(), bufferPointer_))
>= GlobalQType::netw_.batchSize_[*channelPointer_]
&& maxAttempts > 0U) {
while (bufferHead_ - bufferTail_ >= GlobalQType::netw_.batchSize_[*channelPointer_] && maxAttempts > 0U) {
if (not pushOutBuffer()) { --maxAttempts; }

++channelPointer_;
if (channelPointer_ == channelTableEndPointer_) { channelPointer_ = channelIndices_.cbegin(); }
}
if (maxAttempts == 0U) [[unlikely]] { pushOutBufferSelf(outBuffer_.begin()); }
if (maxAttempts == 0U) [[unlikely]] { pushOutBufferSelf(bufferHead_ - bufferTail_); }
}

/**
Expand All @@ -223,20 +226,35 @@ inline bool WorkerResource<GlobalQType, LocalQType, numPorts>::pushOutBuffer() n
bool successfulPush;

const std::size_t batch = GlobalQType::netw_.batchSize_[*channelPointer_];
assert(batch <= static_cast<std::size_t>(std::distance(outBuffer_.begin(), bufferPointer_)));
const typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::iterator itBegin = std::prev(
bufferPointer_,
static_cast<typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::difference_type>(
batch));
assert(batch <= bufferHead_ - bufferTail_);

const std::size_t targetWorker = GlobalQType::netw_.edgeTargets_[*channelPointer_];
if (targetWorker == GlobalQType::netw_.numWorkers_) { // netw.numWorkers_ is reserved for self-push
pushOutBufferSelf(itBegin);
pushOutBufferSelf(batch);
successfulPush = true;
} else {
const std::size_t reducedTail = bufferTail_ % outBuffer_.size();
const std::size_t numElementsFirstPush = std::min(outBuffer_.size() - reducedTail, batch);
const std::size_t numElementsSecondPush = batch - numElementsFirstPush;

const auto itBeginFirst = std::next(
outBuffer_.begin(), static_cast<typename decltype(outBuffer_)::difference_type>(reducedTail));
const auto itEndFirst = std::next(
itBeginFirst, static_cast<typename decltype(outBuffer_)::difference_type>(numElementsFirstPush));
const auto itEndSecond
= std::next(outBuffer_.begin(),
static_cast<typename decltype(outBuffer_)::difference_type>(numElementsSecondPush));

const std::size_t port = GlobalQType::netw_.targetPort_[*channelPointer_];
successfulPush = globalQueue_.pushInternal(itBegin, bufferPointer_, targetWorker, port);
if (successfulPush) { bufferPointer_ = itBegin; }
successfulPush = globalQueue_.pushInternal(itBeginFirst, itEndFirst, targetWorker, port);
if (successfulPush) { bufferTail_ += numElementsFirstPush; }

if (numElementsSecondPush > 0U) {
const bool successfulSecondPush
= globalQueue_.pushInternal(outBuffer_.begin(), itEndSecond, targetWorker, port);
successfulPush |= successfulSecondPush;
if (successfulSecondPush) { bufferTail_ += numElementsSecondPush; };
}
}

return successfulPush;
Expand All @@ -245,25 +263,37 @@ inline bool WorkerResource<GlobalQType, LocalQType, numPorts>::pushOutBuffer() n
/**
* @brief Pushes all task from (including) fromPointer in the outbuffer to the local queue.
*
* @param fromPointer
* @param numElements
*/
template <typename GlobalQType, BasicQueue LocalQType, std::size_t numPorts>
inline void WorkerResource<GlobalQType, LocalQType, numPorts>::pushOutBufferSelf(
const typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::iterator fromPointer) noexcept {
const std::size_t numElements) noexcept {
constexpr bool hasBatchPush
= requires (LocalQType &q,
typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::iterator first,
typename std::array<value_type, GlobalQType::netw_.maxBatchSize()>::iterator last) {
q.push(first, last);
};
typename decltype(outBuffer_)::iterator first,
typename decltype(outBuffer_)::iterator last) { q.push(first, last); };

const std::size_t reducedTail = bufferTail_ % outBuffer_.size();
const std::size_t numElementsFirstPush = std::min(outBuffer_.size() - reducedTail, numElements);
const std::size_t numElementsSecondPush = numElements - numElementsFirstPush;

const auto itBeginFirst = std::next(
outBuffer_.begin(), static_cast<typename decltype(outBuffer_)::difference_type>(reducedTail));
const auto itEndFirst = std::next(
itBeginFirst, static_cast<typename decltype(outBuffer_)::difference_type>(numElementsFirstPush));
const auto itEndSecond
= std::next(outBuffer_.begin(),
static_cast<typename decltype(outBuffer_)::difference_type>(numElementsSecondPush));

if constexpr (hasBatchPush) {
auto it = fromPointer;
queue_.push(it, bufferPointer_);
queue_.push(itBeginFirst, itEndFirst);
queue_.push(outBuffer_.begin(), itEndSecond);
} else {
for (auto it = fromPointer; it != bufferPointer_; ++it) { queue_.push(*it); }
for (auto it = itBeginFirst; it != itEndFirst; ++it) { queue_.push(*it); }
for (auto it = outBuffer_.begin(); it != itEndSecond; ++it) { queue_.push(*it); }
}
bufferPointer_ = fromPointer;

bufferTail_ += numElements;
}

/**
Expand Down Expand Up @@ -306,7 +336,7 @@ inline void WorkerResource<GlobalQType, LocalQType, numPorts>::run(std::stop_tok
++cntr;
}
enqueueInChannels();
pushOutBufferSelf(outBuffer_.begin());
pushOutBufferSelf(bufferHead_ - bufferTail_);
}
}

Expand Down