15 changes: 7 additions & 8 deletions infini_train/src/kernels/cuda/concat.cu
@@ -113,12 +113,12 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
int64_t *device_offsets = nullptr;

CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream));
-CUDA_CHECK(cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
-cudaMemcpyHostToDevice, stream));
+CUDA_CHECK(cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
+cudaMemcpyHostToDevice));

CUDA_CHECK(cudaMallocAsync(&device_offsets, sizeof(int64_t) * (num_inputs + 1), stream));
-CUDA_CHECK(cudaMemcpyAsync(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
-cudaMemcpyHostToDevice, stream));
+CUDA_CHECK(cudaMemcpy(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
+cudaMemcpyHostToDevice));

ConcatForwardKernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
device_input_ptrs, static_cast<T *>(output->DataPtr()), device_offsets, N, D, num_inputs, K_total);
@@ -219,12 +219,11 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
int64_t *device_offsets = nullptr;

CUDA_CHECK(cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream));
-CUDA_CHECK(cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice,
-stream));
+CUDA_CHECK(cudaMemcpy(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice));

CUDA_CHECK(cudaMallocAsync(&device_offsets, sizeof(int64_t) * (num_inputs + 1), stream));
-CUDA_CHECK(cudaMemcpyAsync(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
-cudaMemcpyHostToDevice, stream));
+CUDA_CHECK(cudaMemcpy(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
+cudaMemcpyHostToDevice));

ConcatBackwardKernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
static_cast<const T *>(grad_output->DataPtr()), device_ptrs, device_offsets, N, D, num_inputs, K_total);
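Note on the pattern above (and repeated in the files below): each hunk replaces an asynchronous host-to-device copy whose source is a short-lived `std::vector` staging buffer with a synchronous `cudaMemcpy`. A plausible rationale, not stated in the diff itself, is that the blocking copy guarantees the host data has been consumed before the local vector is destroyed, instead of relying on how the runtime stages pageable memory for `cudaMemcpyAsync`. The sketch below is illustrative only; the helper name and fill logic are hypothetical, not repository code.

```cpp
// Minimal sketch of the before/after pattern, assuming a pageable, function-local
// staging buffer. UploadOffsets is a hypothetical helper, not repository code.
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

void UploadOffsets(int64_t *device_offsets, int64_t num_inputs, cudaStream_t stream) {
    std::vector<int64_t> host_offsets(num_inputs + 1, 0); // pageable host memory

    // Before: copy enqueued on `stream`; the local buffer's lifetime is safe only
    // because the runtime stages the pageable source before returning.
    // cudaMemcpyAsync(device_offsets, host_offsets.data(),
    //                 sizeof(int64_t) * (num_inputs + 1), cudaMemcpyHostToDevice, stream);
    (void)stream; // only needed by the commented-out asynchronous variant

    // After: blocking copy; host_offsets has been fully read before it goes out of scope.
    cudaMemcpy(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
               cudaMemcpyHostToDevice);
}
```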
5 changes: 2 additions & 3 deletions infini_train/src/kernels/cuda/elementwise.cu
@@ -123,8 +123,7 @@ void LaunchForward(Func func, const std::shared_ptr<Tensor> &output, const Input
host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end());
host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());

-cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice,
-cuda_stream);
+cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice);

LaunchKernel<BLOCK_SIZE, T>(
[&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) {
@@ -554,7 +553,7 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end());
host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());

-cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice);

const size_t num_elements = grad_output->NumElements();

18 changes: 7 additions & 11 deletions infini_train/src/kernels/cuda/gather.cu
@@ -89,12 +89,9 @@ std::shared_ptr<Tensor> IndexGatherForward(const std::shared_ptr<Tensor> &input,
int64_t *in_strides_dev = dev_buf + 1 * num_dims;
int64_t *out_strides_dev = dev_buf + 2 * num_dims;

-CUDA_CHECK(
-cudaMemcpyAsync(out_dims_dev, idx_dims.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
-CUDA_CHECK(
-cudaMemcpyAsync(in_strides_dev, in_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
-CUDA_CHECK(cudaMemcpyAsync(out_strides_dev, out_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice,
-stream));
+CUDA_CHECK(cudaMemcpy(out_dims_dev, idx_dims.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice));
+CUDA_CHECK(cudaMemcpy(in_strides_dev, in_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice));
+CUDA_CHECK(cudaMemcpy(out_strides_dev, out_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice));

const int threads = 256;
const int blocks = (total_elements + threads - 1) / threads;
@@ -198,11 +195,10 @@ std::shared_ptr<Tensor> IndexGatherBackward(const std::shared_ptr<Tensor> &grad_
int64_t *in_strides_dev = out_dims_dev + n_out;
int64_t *out_strides_dev = in_strides_dev + n_in_strides;

-CUDA_CHECK(cudaMemcpyAsync(out_dims_dev, idx_dims.data(), n_out * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
-CUDA_CHECK(cudaMemcpyAsync(in_strides_dev, in_strides.data(), n_in_strides * sizeof(int64_t),
-cudaMemcpyHostToDevice, stream));
-CUDA_CHECK(cudaMemcpyAsync(out_strides_dev, out_strides.data(), n_out_strides * sizeof(int64_t),
-cudaMemcpyHostToDevice, stream));
+CUDA_CHECK(cudaMemcpy(out_dims_dev, idx_dims.data(), n_out * sizeof(int64_t), cudaMemcpyHostToDevice));
+CUDA_CHECK(cudaMemcpy(in_strides_dev, in_strides.data(), n_in_strides * sizeof(int64_t), cudaMemcpyHostToDevice));
+CUDA_CHECK(
+cudaMemcpy(out_strides_dev, out_strides.data(), n_out_strides * sizeof(int64_t), cudaMemcpyHostToDevice));

const int threads = 256;
const int blocks = (int)((total_elements + threads - 1) / threads);
24 changes: 10 additions & 14 deletions infini_train/src/kernels/cuda/slice.cu
@@ -81,13 +81,11 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
input_strides_dev = steps_dev + steps.size();
output_strides_dev = input_strides_dev + dims.size();

-cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-stream);
-cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-stream);
+cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice);

int threads_per_block = 256;
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
@@ -175,13 +173,11 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
input_strides_dev = steps_dev + steps.size();
output_strides_dev = input_strides_dev + dims.size();

-cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-stream);
-cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-stream);
+cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice);

int threads_per_block = 256;
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
5 changes: 2 additions & 3 deletions infini_train/src/kernels/cuda/split.cu
@@ -137,14 +137,13 @@ std::shared_ptr<Tensor> LaunchSplitBackward(const std::vector<int64_t> &input_di
device_grad_output_ptrs = (const T **)(device_ptr);
device_H_outs = reinterpret_cast<int64_t *>(device_grad_output_ptrs + num_splits);

-cudaMemcpyAsync(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits,
-cudaMemcpyHostToDevice, stream);
+cudaMemcpy(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits, cudaMemcpyHostToDevice);

// init H_out for each split
std::vector<int64_t> H_outs(num_splits);
for (int i = 0; i < num_splits; ++i) { H_outs[i] = std::min(split_size, H_in - i * split_size); }

-cudaMemcpyAsync(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice, stream);
+cudaMemcpy(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice);

int64_t total_elements = N * H_in * W;
int threads_per_block = 256;
5 changes: 2 additions & 3 deletions infini_train/src/kernels/cuda/stack.cu
@@ -68,8 +68,7 @@ std::shared_ptr<Tensor> StackForward(const std::vector<std::shared_ptr<Tensor>>

const T **device_input_ptrs;
cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream);
-cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice,
-stream);
+cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice);

StackForwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
device_input_ptrs, static_cast<T *>(output->DataPtr()), N, D, num_inputs);
@@ -137,7 +136,7 @@ std::vector<std::shared_ptr<Tensor>> StackBackward(const std::vector<int64_t> &i

T **device_ptrs;
cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream);
-cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice, stream);
+cudaMemcpy(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice);

StackBackwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
static_cast<const T *>(grad_output->DataPtr()), device_ptrs, N, D, num_inputs);
2 changes: 1 addition & 1 deletion infini_train/src/kernels/cuda/transform.cu
@@ -263,7 +263,7 @@ std::shared_ptr<Tensor> TransposeForward(const std::shared_ptr<Tensor> &input, i
host_buffer.insert(host_buffer.end(), in_strides.begin(), in_strides.end());
host_buffer.insert(host_buffer.end(), out_strides.begin(), out_strides.end());

-cudaMemcpyAsync(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+cudaMemcpy(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice);

int threads_per_block = 256;
int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block;
23 changes: 9 additions & 14 deletions infini_train/src/nn/init.cc
@@ -50,9 +50,8 @@ std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean
core::DeviceGuard guard(device);
auto impl = core::GetDeviceGuardImpl(device.type());

-impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
-device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
-impl->GetStream(device));
+impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
+device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);
return tensor;
}

@@ -143,9 +142,8 @@ std::shared_ptr<Tensor> Uniform(const std::shared_ptr<Tensor> &tensor, float a,
core::DeviceGuard guard(device);
auto impl = core::GetDeviceGuardImpl(device.type());

-impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
-device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
-impl->GetStream(device));
+impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
+device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);

return tensor;
}
@@ -161,9 +159,8 @@ std::shared_ptr<Tensor> Ones(const std::shared_ptr<Tensor> &tensor) {

auto impl = core::GetDeviceGuardImpl(device.type());

-impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
-device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
-impl->GetStream(device));
+impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
+device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);

return tensor;
}
@@ -179,9 +176,8 @@ std::shared_ptr<Tensor> Zeros(const std::shared_ptr<Tensor> &tensor) {

auto impl = core::GetDeviceGuardImpl(device.type());

-impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
-device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
-impl->GetStream(device));
+impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
+device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);

return tensor;
}
@@ -190,7 +186,7 @@ std::shared_ptr<Tensor> Zeros(const std::shared_ptr<Tensor> &tensor) {
case DATA_TYPE: { \
std::vector<TYPE> buffer(num_elements); \
std::iota(buffer.begin(), buffer.end(), static_cast<TYPE>(start)); \
-impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind, stream); \
+impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind); \
break; \
}

@@ -202,7 +198,6 @@ std::shared_ptr<Tensor> Arange(int64_t start, int64_t end, DataType dtype, Devic
auto *impl = core::GetDeviceGuardImpl(device.type());

const core::MemcpyKind kind = device.IsCPU() ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D;
-core::Stream *stream = impl->GetStream(device);

switch (dtype) {
ARANGE_CASE(DataType::kUINT8, uint8_t)
22 changes: 9 additions & 13 deletions infini_train/src/tensor.cc
@@ -67,9 +67,8 @@ Tensor::Tensor(const float *data, const std::vector<int64_t> &dims, DataType dty

core::DeviceGuard guard(device);
auto *impl = core::GetDeviceGuardImpl(device.type());
-impl->MemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(),
-device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
-impl->GetStream(device));
+impl->Memcpy(buffer_->DataPtr(), data, buffer_->Size(),
+device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);
}

void Tensor::SetData(const Tensor &tensor, size_t offset, bool preserve_data) {
@@ -162,16 +161,14 @@ Tensor Tensor::To(Device device) {
new_tensor = Tensor(dims_, dtype_, Device());
core::DeviceGuard guard(buffer_device);
auto impl = core::GetDeviceGuardImpl(buffer_device.type());
-impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H,
-impl->GetStream(buffer_device));
+impl->Memcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H);

} else if (buffer_device.type() == Device::DeviceType::kCPU) {
new_tensor = Tensor(dims_, dtype_, device);
// H2D
core::DeviceGuard guard(device);
auto *impl = core::GetDeviceGuardImpl(device.type());
-impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D,
-impl->GetStream(device));
+impl->Memcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D);
} else {
new_tensor = Tensor(dims_, dtype_, device);
// P2P
@@ -180,8 +177,7 @@ Tensor Tensor::To(Device device) {
// 2. H2D
core::DeviceGuard guard(buffer_device);
auto *impl = core::GetDeviceGuardImpl(buffer_device.type());
-impl->MemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D,
-impl->GetStream(buffer_device));
+impl->Memcpy(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D);
}

if (grad_) {
Expand Down Expand Up @@ -230,17 +226,17 @@ void Tensor::CopyFrom(const Tensor &src) {
if (dst_dev == src_dev) {
core::DeviceGuard guard(dst_dev);
auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
-impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev));
+impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D);
} else if (dst_dev.type() == Device::DeviceType::kCPU) {
// D2H
core::DeviceGuard guard(src_dev);
auto *impl = core::GetDeviceGuardImpl(src_dev.type());
-impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev));
+impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H);
} else if (src_dev.type() == Device::DeviceType::kCPU) {
// H2D
core::DeviceGuard guard(dst_dev);
auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
-impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev));
+impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D);
} else {
// TODO(dcj): maybe support p2p api later
// P2P
@@ -250,7 +246,7 @@ void Tensor::CopyFrom(const Tensor &src) {
// 2. H2D
core::DeviceGuard guard(dst_dev);
auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
-impl->MemcpyAsync(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev));
+impl->Memcpy(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D);
}
}

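For reference, if asynchronous uploads were ever wanted again for overlap, the usual way to make them safe is to stage through pinned (page-locked) host memory rather than a pageable `std::vector`. This is an alternative sketch under that assumption, not what this change does; the helper name is hypothetical.

```cpp
// Alternative sketch (not used by this diff): keep the upload asynchronous by
// staging through pinned host memory. UploadOffsetsAsyncPinned is hypothetical.
#include <cuda_runtime.h>
#include <cstdint>

void UploadOffsetsAsyncPinned(const int64_t *host_src, int64_t count,
                              int64_t *device_dst, cudaStream_t stream) {
    int64_t *pinned = nullptr;
    cudaMallocHost(reinterpret_cast<void **>(&pinned), count * sizeof(int64_t)); // page-locked
    for (int64_t i = 0; i < count; ++i) { pinned[i] = host_src[i]; }             // stage the data

    // Safe to enqueue: pinned memory can be DMA'd directly and is kept alive below.
    cudaMemcpyAsync(device_dst, pinned, count * sizeof(int64_t), cudaMemcpyHostToDevice, stream);

    // The pinned buffer must outlive the copy. Synchronizing here is the simplest
    // (blocking) way to guarantee that; a real implementation would more likely
    // cache the buffer or release it from a stream callback.
    cudaStreamSynchronize(stream);
    cudaFreeHost(pinned);
}
```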