-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNetRunner.cpp
More file actions
72 lines (66 loc) · 2.22 KB
/
NetRunner.cpp
File metadata and controls
72 lines (66 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include "NetRunner.hpp"
// Builds a runner that batches inference requests and evaluates them on a
// dedicated worker thread.
//
// net        - network whose forward() is invoked by the worker.
// device     - device the batched input is moved to before forward().
// batch_size - number of queued requests evaluated per forward() call; also the
//              number of output slots allocated here.
//              NOTE(review): request_run() computes `% batch_size`, so this is
//              presumably expected to be >= 1 - confirm against callers.
NetRunner::NetRunner(std::shared_ptr<Net> net, std::shared_ptr<torch::Device> device, int batch_size)
{
    this->net = net;
    this->device = device;
    this->done = false;
    this->cnt = 0;
    this->batch_size = batch_size;
    this->output = std::vector<torch::Tensor>(batch_size);
    // Start the worker last: run_loop() reads batch_size (and later output)
    // as soon as it runs, so every member it touches must already be set.
    this->runner = std::thread(&NetRunner::run_loop, this);
}
// Queues `input` for the worker thread and blocks the caller until the worker
// marks this request as finished, then returns the tensor stored in the output
// slot reserved for it.
//
// The wake-up flag and condition variable live on this call's stack; the
// worker reaches them through the pointers pushed onto the queues below.
torch::Tensor NetRunner::request_run(torch::Tensor input)
{
    std::unique_lock<std::mutex> guard(this->mtx);

    std::condition_variable wake_cv;
    bool finished = false;

    // Enqueue the payload first, then reserve the output slot for this request.
    this->to_process.push_back(input);
    const int slot = this->cnt;
    this->cnt = (this->cnt + 1) % this->batch_size;

    // Register how the worker should signal completion of this request.
    this->to_process_cv.push_back(&wake_cv);
    this->ready_list.push_back(&finished);

    // Let the worker re-check whether a full batch is now available.
    this->runner_cv.notify_one();

    // Sleep until the worker flips the flag; the loop absorbs spurious wake-ups.
    while (!finished)
    {
        wake_cv.wait(guard);
    }
    return this->output[slot];
}
// Asks the worker thread to stop and waits for it to exit.
//
// Calling this again after the worker has been joined is a no-op.
// NOTE(review): callers still blocked in request_run() are not woken here;
// confirm that no requests are outstanding when this is called.
void NetRunner::shutdown()
{
    {
        // Set the flag under the same mutex the worker's wait predicate uses.
        // Otherwise the worker could test `done`, see false, and go to sleep
        // right after the notification below, leaving join() blocked forever.
        std::lock_guard<std::mutex> lck(this->mtx);
        this->done = true;
    }
    this->runner_cv.notify_one();
    // join() on a thread that is not joinable throws std::system_error.
    if (this->runner.joinable())
    {
        this->runner.join();
    }
}
void NetRunner::run_loop()
{
torch::NoGradGuard no_grad;
while (!this->done)
{
std::unique_lock<std::mutex> lck(this->mtx);
this->runner_cv.wait(lck, [this]{return this->done || (int)this->to_process.size() >= this->batch_size;}); // wait for a request
if (this->done)
{
break;
}
std::vector<torch::Tensor> inputs(batch_size);
for (int i = 0; i < batch_size; i++)
{
inputs[i] = this->to_process[i];
}
// std::cout << "running net" << std::endl;
torch::Tensor input = torch::cat(inputs);
torch::Tensor out = this->net->forward(input.to(*(this->device)));
out = out.to(torch::kCPU);
for (int i = 0; i < batch_size; i++)
{
this->output[i] = out.slice(0, i, i + 1);
}
for (int i = 0; i < batch_size; i++)
{
*(this->ready_list[i]) = true;
this->to_process_cv[i]->notify_one();
}
// Not incredibly efficient, but those vectors should be relatively small
this->to_process.erase(this->to_process.begin(), this->to_process.begin() + batch_size);
this->to_process_cv.erase(this->to_process_cv.begin(), this->to_process_cv.begin() + batch_size);
this->ready_list.erase(this->ready_list.begin(), this->ready_list.begin() + batch_size);
}
}