Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/include/thread/itasksys.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef _ITASKSYS_H
#define _ITASKSYS_H
#include "common/Status.hpp"
#include <vector>
#pragma once

using TaskID = int;
using DB::Status;

class IRunnable {
public:
virtual ~IRunnable();

virtual void runTask(int task_id, int num_total_tasks) = 0;
};

class ITaskSystem {
public:
ITaskSystem(int num_threads);
virtual ~ITaskSystem();

virtual void run(IRunnable *runnable, int num_total_tasks) = 0;

virtual TaskID runAsyncWithDeps(IRunnable *runnable, int num_total_tasks,
const std::vector<TaskID> &deps) = 0;

virtual void sync() = 0;
};
#endif
82 changes: 82 additions & 0 deletions src/include/thread/tasksys.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef _TASKSYS_H
#define _TASKSYS_H
#pragma once

#include "common/Status.hpp"
#include "itasksys.hpp"
#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <set>
#include <thread>

class TaskGroup {
public:
int groupId{};
int total_num_tasks;
IRunnable *runnable;
std::atomic<int> taskRemain;
std::set<TaskID> depending;

TaskGroup(int groupId, IRunnable *runnable, int numTotalTasks,
const std::vector<TaskID> &deps) {
this->groupId = groupId;
this->runnable = runnable;
this->total_num_tasks = numTotalTasks;
this->taskRemain = numTotalTasks;
this->depending = {};
for (auto dep : deps) {
this->depending.insert(dep);
}
}
friend bool operator<(const TaskGroup &a, const TaskGroup &b) {
return a.depending.size() > b.depending.size();
}
};

struct RunnableTask {
public:
TaskGroup *taskGroup;
int id;
RunnableTask(TaskGroup *taskGroup_, int id_)
: id(id_), taskGroup(taskGroup_) {}
};

class TaskSystemParallelThreadPoolSleeping : public ITaskSystem {
public:
TaskSystemParallelThreadPoolSleeping(int num_threads);

~TaskSystemParallelThreadPoolSleeping();

template <typename Func, typename... Args>
auto run(Func &&func, Args &&...args)
-> std::future<typename std::invoke_result<Func, Args...>::type>;

private:
void run(IRunnable *runnable, int num_total_tasks);

TaskID runAsyncWithDeps(IRunnable *runnable, int num_total_tasks,
const std::vector<TaskID> &deps);

void sync();

void func();

std::vector<std::thread> threads;
std::queue<RunnableTask *> taskQueue;
std::set<TaskGroup *> taskGroupSet;
std::priority_queue<TaskGroup *> taskGroupQueue;
std::atomic<int> taskRemained;
bool exitFlag;
bool finishFlag;
int numGroup;
std::mutex cntMtx;
std::mutex queueMtx;
std::condition_variable countCond;
std::condition_variable queueCond;
};

#endif
144 changes: 144 additions & 0 deletions src/thread/tasksys.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#include "thread/tasksys.hpp"
#include "common/Status.hpp"
#include "common/ZeitgeistDB.hpp"
#include <functional>
#include <memory>
#include <mutex>
#include <utility>

IRunnable::~IRunnable() {}

ITaskSystem::ITaskSystem(int num_threads) {}
ITaskSystem::~ITaskSystem() {}

class FunctionRunnable : public IRunnable {
public:
FunctionRunnable(std::function<void()> func) : func_(func) {}

void runTask(int task_id, int num_total_tasks) override { func_(); }

private:
std::function<void()> func_;
};

TaskSystemParallelThreadPoolSleeping::TaskSystemParallelThreadPoolSleeping(
int num_threads)
: ITaskSystem(num_threads) {
exitFlag = false;
numGroup = 0;
for (int i = 0; i < num_threads; i++) {
threads.emplace_back(&TaskSystemParallelThreadPoolSleeping::func, this);
}
}

TaskSystemParallelThreadPoolSleeping::~TaskSystemParallelThreadPoolSleeping() {
exitFlag = true;
queueCond.notify_all();
for (auto &thread : threads) {
thread.join();
}
}

void TaskSystemParallelThreadPoolSleeping::run(IRunnable *runnable,
int num_total_tasks) {
runAsyncWithDeps(runnable, num_total_tasks, {});
sync();
}

void TaskSystemParallelThreadPoolSleeping::func() {
RunnableTask *task;
TaskGroup *task_group;
while (true) {
queueCond.notify_all();
while (true) {
std::unique_lock<std::mutex> lock(queueMtx);
queueCond.wait(lock, [this] { return exitFlag || !taskQueue.empty(); });
if (exitFlag) {
return;
}
if (taskQueue.empty()) {
continue;
}
task = taskQueue.front();
taskQueue.pop();
break;
}
task_group = task->taskGroup;
task_group->runnable->runTask(task->id, task_group->total_num_tasks);
task_group->taskRemain--;
if (task_group->taskRemain <= 0) {
for (auto &task : taskGroupSet) {
task_group->depending.erase(task->groupId);
}
countCond.notify_one();
} else {
queueCond.notify_all();
}
}
}

TaskID TaskSystemParallelThreadPoolSleeping::runAsyncWithDeps(
IRunnable *runnable, int num_total_tasks, const std::vector<TaskID> &deps) {
auto new_task_group =
new TaskGroup(numGroup, runnable, num_total_tasks, deps);
taskGroupQueue.push(new_task_group);
taskGroupSet.insert(new_task_group);

return numGroup++;
}

void TaskSystemParallelThreadPoolSleeping::sync() {
TaskGroup *task_group;
RunnableTask *runnable_group;

while (!taskGroupQueue.empty()) {
task_group = taskGroupQueue.top();
if (!task_group->depending.empty()) {
continue;
}
queueMtx.lock();
for (int i = 0; i < task_group->total_num_tasks; i++) {
runnable_group = new RunnableTask(task_group, i);
taskQueue.push(runnable_group);
}
queueMtx.unlock();
queueCond.notify_all();
taskGroupQueue.pop();
}
while (true) {
std::unique_lock<std::mutex> lock(cntMtx);
countCond.wait(lock, [this] { return finishFlag; });
finishFlag = true;
for (auto &task_group : taskGroupSet) {
if (task_group->taskRemain > 0) {
finishFlag = false;
break;
}
}
if (finishFlag) {
return;
}
}
return;
}

template <typename Func, typename... Args>
auto TaskSystemParallelThreadPoolSleeping::run(Func &&func, Args &&...args)
-> std::future<typename std::invoke_result<Func, Args...>::type> {
using return_type = typename std::result_of<Func(Args...)>::type;

auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<Func>(func), std::forward<Args>(args)...));

std::future<return_type> result = task->get_future();

auto wrapped_func = [task]() { (*task)(); };

std::shared_ptr<FunctionRunnable> runnable =
std::make_shared<FunctionRunnable>(wrapped_func);

runAsyncWithDeps(runnable.get(), 1, {});
sync();

return result;
}