diff --git a/src/include/thread/itasksys.hpp b/src/include/thread/itasksys.hpp new file mode 100644 index 0000000..8bfd0bc --- /dev/null +++ b/src/include/thread/itasksys.hpp @@ -0,0 +1,29 @@ +#ifndef _ITASKSYS_H +#define _ITASKSYS_H +#include "common/Status.hpp" +#include +#pragma once + +using TaskID = int; +using DB::Status; + +class IRunnable { +public: + virtual ~IRunnable(); + + virtual void runTask(int task_id, int num_total_tasks) = 0; +}; + +class ITaskSystem { +public: + ITaskSystem(int num_threads); + virtual ~ITaskSystem(); + + virtual void run(IRunnable *runnable, int num_total_tasks) = 0; + + virtual TaskID runAsyncWithDeps(IRunnable *runnable, int num_total_tasks, + const std::vector &deps) = 0; + + virtual void sync() = 0; +}; +#endif diff --git a/src/include/thread/tasksys.hpp b/src/include/thread/tasksys.hpp new file mode 100644 index 0000000..d21dd19 --- /dev/null +++ b/src/include/thread/tasksys.hpp @@ -0,0 +1,82 @@ +#ifndef _TASKSYS_H +#define _TASKSYS_H +#pragma once + +#include "common/Status.hpp" +#include "itasksys.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +class TaskGroup { +public: + int groupId{}; + int total_num_tasks; + IRunnable *runnable; + std::atomic taskRemain; + std::set depending; + + TaskGroup(int groupId, IRunnable *runnable, int numTotalTasks, + const std::vector &deps) { + this->groupId = groupId; + this->runnable = runnable; + this->total_num_tasks = numTotalTasks; + this->taskRemain = numTotalTasks; + this->depending = {}; + for (auto dep : deps) { + this->depending.insert(dep); + } + } + friend bool operator<(const TaskGroup &a, const TaskGroup &b) { + return a.depending.size() > b.depending.size(); + } +}; + +struct RunnableTask { +public: + TaskGroup *taskGroup; + int id; + RunnableTask(TaskGroup *taskGroup_, int id_) + : id(id_), taskGroup(taskGroup_) {} +}; + +class TaskSystemParallelThreadPoolSleeping : public ITaskSystem { +public: + TaskSystemParallelThreadPoolSleeping(int num_threads); + + ~TaskSystemParallelThreadPoolSleeping(); + + template + auto run(Func &&func, Args &&...args) + -> std::future::type>; + +private: + void run(IRunnable *runnable, int num_total_tasks); + + TaskID runAsyncWithDeps(IRunnable *runnable, int num_total_tasks, + const std::vector &deps); + + void sync(); + + void func(); + + std::vector threads; + std::queue taskQueue; + std::set taskGroupSet; + std::priority_queue taskGroupQueue; + std::atomic taskRemained; + bool exitFlag; + bool finishFlag; + int numGroup; + std::mutex cntMtx; + std::mutex queueMtx; + std::condition_variable countCond; + std::condition_variable queueCond; +}; + +#endif diff --git a/src/thread/tasksys.cpp b/src/thread/tasksys.cpp new file mode 100644 index 0000000..d526548 --- /dev/null +++ b/src/thread/tasksys.cpp @@ -0,0 +1,144 @@ +#include "thread/tasksys.hpp" +#include "common/Status.hpp" +#include "common/ZeitgeistDB.hpp" +#include +#include +#include +#include + +IRunnable::~IRunnable() {} + +ITaskSystem::ITaskSystem(int num_threads) {} +ITaskSystem::~ITaskSystem() {} + +class FunctionRunnable : public IRunnable { +public: + FunctionRunnable(std::function func) : func_(func) {} + + void runTask(int task_id, int num_total_tasks) override { func_(); } + +private: + std::function func_; +}; + +TaskSystemParallelThreadPoolSleeping::TaskSystemParallelThreadPoolSleeping( + int num_threads) + : ITaskSystem(num_threads) { + exitFlag = false; + numGroup = 0; + for (int i = 0; i < num_threads; i++) { + threads.emplace_back(&TaskSystemParallelThreadPoolSleeping::func, this); + } +} + +TaskSystemParallelThreadPoolSleeping::~TaskSystemParallelThreadPoolSleeping() { + exitFlag = true; + queueCond.notify_all(); + for (auto &thread : threads) { + thread.join(); + } +} + +void TaskSystemParallelThreadPoolSleeping::run(IRunnable *runnable, + int num_total_tasks) { + runAsyncWithDeps(runnable, num_total_tasks, {}); + sync(); +} + +void TaskSystemParallelThreadPoolSleeping::func() { + RunnableTask *task; + TaskGroup *task_group; + while (true) { + queueCond.notify_all(); + while (true) { + std::unique_lock lock(queueMtx); + queueCond.wait(lock, [this] { return exitFlag || !taskQueue.empty(); }); + if (exitFlag) { + return; + } + if (taskQueue.empty()) { + continue; + } + task = taskQueue.front(); + taskQueue.pop(); + break; + } + task_group = task->taskGroup; + task_group->runnable->runTask(task->id, task_group->total_num_tasks); + task_group->taskRemain--; + if (task_group->taskRemain <= 0) { + for (auto &task : taskGroupSet) { + task_group->depending.erase(task->groupId); + } + countCond.notify_one(); + } else { + queueCond.notify_all(); + } + } +} + +TaskID TaskSystemParallelThreadPoolSleeping::runAsyncWithDeps( + IRunnable *runnable, int num_total_tasks, const std::vector &deps) { + auto new_task_group = + new TaskGroup(numGroup, runnable, num_total_tasks, deps); + taskGroupQueue.push(new_task_group); + taskGroupSet.insert(new_task_group); + + return numGroup++; +} + +void TaskSystemParallelThreadPoolSleeping::sync() { + TaskGroup *task_group; + RunnableTask *runnable_group; + + while (!taskGroupQueue.empty()) { + task_group = taskGroupQueue.top(); + if (!task_group->depending.empty()) { + continue; + } + queueMtx.lock(); + for (int i = 0; i < task_group->total_num_tasks; i++) { + runnable_group = new RunnableTask(task_group, i); + taskQueue.push(runnable_group); + } + queueMtx.unlock(); + queueCond.notify_all(); + taskGroupQueue.pop(); + } + while (true) { + std::unique_lock lock(cntMtx); + countCond.wait(lock, [this] { return finishFlag; }); + finishFlag = true; + for (auto &task_group : taskGroupSet) { + if (task_group->taskRemain > 0) { + finishFlag = false; + break; + } + } + if (finishFlag) { + return; + } + } + return; +} + +template +auto TaskSystemParallelThreadPoolSleeping::run(Func &&func, Args &&...args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared>( + std::bind(std::forward(func), std::forward(args)...)); + + std::future result = task->get_future(); + + auto wrapped_func = [task]() { (*task)(); }; + + std::shared_ptr runnable = + std::make_shared(wrapped_func); + + runAsyncWithDeps(runnable.get(), 1, {}); + sync(); + + return result; +} \ No newline at end of file