Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion runtime/include/brt/core/framework/execution_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ExecutionPlan {
* It selects ExecutionProviders based on ranking.
*/

class StaticBRTExecutionPlan final : public ExecutionPlan {
class StaticBRTExecutionPlan : public ExecutionPlan {
public:
StaticBRTExecutionPlan(brt::ir::ByREHandle &);

Expand Down Expand Up @@ -162,4 +162,46 @@ class StaticBRTExecutionPlan final : public ExecutionPlan {
std::vector<OpKernel *> compute_op_kernels_;
};

class MultiStreamExecutionPlan : public StaticBRTExecutionPlan {

public:
MultiStreamExecutionPlan(brt::ir::ByREHandle &);

common::Status ProloguePerSession(
const std::unordered_map<std::string, std::unique_ptr<IAllocator>>
&allocators,
const std::vector<std::unique_ptr<ExecutionProvider>> &providers,
const Device dev, const DeviceAPI *device_api) override;

common::Status EpiloguePerSession() override;

void CreateWorkQueue(std::unique_ptr<WorkQueue> *wq, int rank = 0) override;

void CreateExecutinFrame(std::unique_ptr<ExecutionFrame> *frame) override;

common::Status ProloguePerFrame(const ExecutionContext &) override;
common::Status EpiloguePerFrame(const ExecutionContext &) override;

common::Status Run(const ExecutionContext &) override;

using PartitionGraphMethod = std::function<int(OpKernel *kernel)>;

void SetPartitionGraphMethod(PartitionGraphMethod method) {
partition_graph_method_ = method;
}

private:
std::unordered_map<OpKernel *, int> kernel_stream_map_;
std::vector<std::vector<OpKernel *>> logical_streams_;
std::vector<cudaStream_t> cuda_streams_;

void PartitionOpKernels(PartitionGraphMethod method);
std::unordered_map<OpKernel *, int64_t> kernel_to_event_index;
std::unordered_map<OpKernel *, std::vector<int64_t>> kernel_to_wait_events;
std::vector<cudaEvent_t> event_list_;

void AnalyzeStreamDependency();

int max_stream_num_;

} // namespace brt
87 changes: 87 additions & 0 deletions runtime/lib/core/framework/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,4 +699,91 @@ void StaticBRTExecutionPlan::IterateOpKernels(
}
}

// Constructs a multi-stream plan. All construction work is delegated to the
// StaticBRTExecutionPlan base class, which receives the same handle.
MultiStreamExecutionPlan::MultiStreamExecutionPlan(ByREHandle &handle)
    : StaticBRTExecutionPlan(handle) {
}

/**
 * Assigns every compute OpKernel to a logical stream and creates one CUDA
 * stream per logical stream.
 *
 * @param method callback returning the stream id for a kernel. Negative ids
 *               are clamped to 0 so they can never index out of bounds.
 *
 * After the call, num_streams_ equals the largest assigned id plus one (so it
 * is 1 when there are no compute kernels), and logical_streams_ and
 * cuda_streams_ both have num_streams_ entries.
 */
void MultiStreamExecutionPlan::PartitionOpKernels(PartitionGraphMethod method) {
  // Kept as `int` to match the callback's return type: std::max cannot deduce
  // a common type from mixed int64_t / int arguments.
  int max_stream_id = 0;
  for (auto kernel : compute_op_kernels_) {
    const int stream_id = std::max(0, method(kernel));
    max_stream_id = std::max(max_stream_id, stream_id);
    kernel_stream_map_[kernel] = stream_id;
  }

  num_streams_ = max_stream_id + 1;
  logical_streams_.resize(num_streams_);
  cuda_streams_.resize(num_streams_);

  // Second pass: the per-stream lists can only be filled once their count is
  // known. Iterating compute_op_kernels_ preserves the original kernel order
  // inside each stream.
  for (auto kernel : compute_op_kernels_) {
    logical_streams_[kernel_stream_map_[kernel]].push_back(kernel);
  }

  // NOTE(review): the cudaStreamCreate result is not checked and the streams
  // are not destroyed anywhere in the visible code - confirm error handling
  // and cleanup (e.g. in EpiloguePerSession).
  for (auto &stream : cuda_streams_) {
    cudaStreamCreate(&stream);
  }
}

void MultiStreamExecutionPlan::AnalyzeStreamDependency() {
for (int i = 0; i < num_streams_; ++i) {
for (auto kernel : logical_streams_[i]) {
for (auto dep_index : kernel->GetDependencyList()) {
auto dep_kernel = op_kernels_[dep_index];
int dep_stream_id = kernel_stream_map_[dep_kernel];
if (dep_stream_id != i) {
if (kernel_to_event_index.find(dep_kernel) ==
kernel_to_event_index.end()) {
kernel_to_event_index[dep_kernel] = event_list.size();
event_list.push_back(cudaEvent_t());
}
if (kernel_to_wait_events.find(kernel) ==
kernel_to_wait_events.end()) {
kernel_to_wait_events[kernel] = {};
}
kernel_to_wait_events[kernel].push_back(
kernel_to_event_index[dep_kernel]);
}
}
}
}
}

void MultiStreamExecutionPlan::Run(const ExecutionContext &context) {
context.event_listener_manager->SignalEvent<Events::BeforeExecutionPlanRun>(
{});

std::vector<ExecutionContext> stream_contexts(num_streams_, context);
for (int i = 0; i < num_streams_; ++i) {
stream_contexts[i].stream = cuda_streams_[i];
}

for (auto op : shape_op_kernels_) {
common::Status status = op->Run(context);
if (!status.IsOK()) {
return status;
}
}

for (auto op : compute_op_kernels_) {
int stream_id = kernel_stream_map_[op];
if (kernel_to_wait_events.find(op) != kernel_to_wait_events.end()) {
for (auto event_index : kernel_to_wait_events[op]) {
cudaStreamWaitEvent(cuda_streams_[stream_id], event_list[event_index],
0);
}
}
common::Status status = op->Run(stream_contexts[stream_id]);
if (!status.IsOK()) {
return status;
}
if (kernel_to_event_index.find(op) != kernel_to_event_index.end()) {
cudaEventCreate(&event_list[kernel_to_event_index[op]]);
cudaEventRecord(event_list[kernel_to_event_index[op]],
cuda_streams_[stream_id]);
}
}

context.event_listener_manager->SignalEvent<Events::AfterExecutionPlanRun>(
{});

} // namespace brt