Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/thread_pool/queue_props.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <algorithm>

namespace dp {

struct queue_props final{
enum class overflow_action { DISCARD_OLDEST, DISCARD_NEWEST };
private:
size_t queue_size_;
overflow_action overflow_action_;
bool notify_;
public:
queue_props():
queue_size_{0},
overflow_action_{overflow_action::DISCARD_NEWEST},
notify_{false}
{}
queue_props(size_t queueSize, bool notify=false):
queue_size_{queueSize},
overflow_action_{overflow_action::DISCARD_NEWEST},
notify_{notify}
{}
queue_props(size_t queueSize, overflow_action overflowAction, bool notify=false):
queue_size_{queueSize},
overflow_action_{overflowAction},
notify_{notify}
{}

bool will_notify() const { return notify_; }
bool is_infinite() const { return queue_size_ == 0; }
bool no_queue() const { return queue_size_ == 1; }
size_t get_queue_size() const { return queue_size_; }
overflow_action get_overflow_action() const { return overflow_action_; }
size_t num_threads(size_t requested) const { return overflow_action_ == overflow_action::DISCARD_OLDEST ? std::min((size_t)1,requested) : requested; }
};

} // namespace dp
46 changes: 42 additions & 4 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#endif

#include "thread_safe_queue.h"
#include "queue_props.h"

namespace dp {
namespace details {
Expand Down Expand Up @@ -114,13 +115,38 @@ namespace dp {
}
}

~thread_pool() {
wait_for_tasks();

/**
* @brief stops the thread pool with a non-blocking request.
* @details This does not clear queued tasks. But it does prevent
* new tasks from being generated queued.
*/
void stop_non_blocking() {
// if stopped_ return, otherwise set to true
bool FALSE = false;
if(!stopped_.compare_exchange_strong(FALSE, true)) return;
// stop all threads
for (std::size_t i = 0; i < threads_.size(); ++i) {
threads_[i].request_stop();
tasks_[i].signal.release();
}
}

/**
* @brief stops the thread pool and waits for all queued tasks to complete.
* @details This does not clear queued tasks. But it does prevent
* new tasks from being generated queued.
*/
void stop() {
stop_non_blocking();
wait_for_tasks();
}

/**
* @brief Destroy the thread pool object, calling stop() in the process.
*/
~thread_pool() {
stop();
for (std::size_t i = 0; i < threads_.size(); ++i) {
threads_[i].join();
}
}
Expand All @@ -143,6 +169,7 @@ namespace dp {
typename ReturnType = std::invoke_result_t<Function&&, Args&&...>>
requires std::invocable<Function, Args...>
[[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
if(stopped_.load()) throw std::runtime_error("Attempted to enqueue a new task to a stopped thread pool");
#ifdef __cpp_lib_move_only_function
// we can do this in C++23 because we now have support for move only functions
std::promise<ReturnType> promise;
Expand Down Expand Up @@ -208,6 +235,7 @@ namespace dp {
template <typename Function, typename... Args>
requires std::invocable<Function, Args...>
void enqueue_detach(Function&& func, Args&&... args) {
if(stopped_.load()) throw std::runtime_error("Attempted to enqueue a new task to a stopped thread pool");
enqueue_task(
std::move([f = std::forward<Function>(func),
... largs = std::forward<Args>(args)]() mutable -> decltype(auto) {
Expand Down Expand Up @@ -262,6 +290,15 @@ namespace dp {
return removed_task_count;
}

/**
* @brief Check if this threadpool has been requested to stop
*
* @return true if stopped, false otherwise
*/
bool stop_requested() const {
return stopped_;
}

private:
template <typename Function>
void enqueue_task(Function&& f) {
Expand Down Expand Up @@ -291,13 +328,14 @@ namespace dp {
dp::thread_safe_queue<FunctionType> tasks{};
std::binary_semaphore signal{0};
};

std::vector<ThreadType> threads_;
std::deque<task_item> tasks_;
dp::thread_safe_queue<std::size_t> priority_queue_;
// guarantee these get zero-initialized
std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0};
std::atomic_bool threads_complete_signal_{false};
std::atomic_bool stopped_{false};
};

/**
Expand Down