From 298267a4cc5133fc821ac6054f5556c85780886c Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 14 Jan 2026 17:24:52 +0000 Subject: [PATCH 1/7] feat: implement device registration infrastructure --- infini_train/include/core/blas_handle.h | 11 + infini_train/include/core/device_guard.h | 209 ++++++++++++++++++ infini_train/include/core/stream.h | 10 + infini_train/include/device.h | 99 ++------- infini_train/src/core/cpu/cpu_guard.cc | 18 ++ infini_train/src/core/cpu/cpu_guard.h | 22 ++ .../src/core/cuda/cuda_blas_handle.cc | 15 ++ infini_train/src/core/cuda/cuda_blas_handle.h | 21 ++ infini_train/src/core/cuda/cuda_guard.cc | 131 +++++++++++ infini_train/src/core/cuda/cuda_guard.h | 54 +++++ infini_train/src/core/cuda/cuda_stream.cc | 12 + infini_train/src/core/cuda/cuda_stream.h | 19 ++ infini_train/src/core/device_guard.cc | 149 +++++++++++++ infini_train/src/device.cc | 115 +--------- .../src/kernels/cuda/accumulate_grad.cu | 8 +- 15 files changed, 707 insertions(+), 186 deletions(-) create mode 100644 infini_train/include/core/blas_handle.h create mode 100644 infini_train/include/core/device_guard.h create mode 100644 infini_train/include/core/stream.h create mode 100644 infini_train/src/core/cpu/cpu_guard.cc create mode 100644 infini_train/src/core/cpu/cpu_guard.h create mode 100644 infini_train/src/core/cuda/cuda_blas_handle.cc create mode 100644 infini_train/src/core/cuda/cuda_blas_handle.h create mode 100644 infini_train/src/core/cuda/cuda_guard.cc create mode 100644 infini_train/src/core/cuda/cuda_guard.h create mode 100644 infini_train/src/core/cuda/cuda_stream.cc create mode 100644 infini_train/src/core/cuda/cuda_stream.h create mode 100644 infini_train/src/core/device_guard.cc diff --git a/infini_train/include/core/blas_handle.h b/infini_train/include/core/blas_handle.h new file mode 100644 index 00000000..56b058ce --- /dev/null +++ b/infini_train/include/core/blas_handle.h @@ -0,0 +1,11 @@ +#pragma once + +namespace infini_train::core { + +class BlasHandle { +public: + BlasHandle(){}; + virtual ~BlasHandle() = default; +}; + +} // namespace infini_train::core diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h new file mode 100644 index 00000000..af37e42a --- /dev/null +++ b/infini_train/include/core/device_guard.h @@ -0,0 +1,209 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/device.h" + +namespace infini_train::core { + +class Stream; +class BlasHandle; + +enum class MemcpyKind : int8_t { + kH2D = 0, + kD2H = 1, + kD2D = 2, + kInvalid = -1, +}; + +// +// ---------------------------------------------------------------------------- +// DeviceGuardImpl: Backend-specific device/stream/memory/BLAS implementation +// ---------------------------------------------------------------------------- +// This is the low-level virtual interface that each backend must implement. +// Examples: +// - CUDA: CudaDeviceGuardImpl +// - CPU: CpuDeviceGuardImpl +// - Custom: MyChipDeviceGuardImpl +// +// DeviceGuardImpl encapsulates **all device-runtime behaviors**, including: +// +// • Querying / setting the current device +// • Stream creation/lookup +// • Synchronization primitives +// • Memory allocation & copy +// • Access to BLAS handles +// +// DeviceGuard (the public RAII wrapper) forwards calls to the DeviceGuardImpl +// instance registered for the device type. 
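+//
+// A new backend only needs to override the pure-virtual hooks declared below
+// (GetDevice, Type, Malloc, Free, Memcpy); every other virtual has a
+// base-class default in device_guard.cc, so overriding it is optional.
+// Illustrative sketch only, reusing the hypothetical MyChipDeviceGuardImpl
+// name from the list above:
+//
+//   class MyChipDeviceGuardImpl : public DeviceGuardImpl {
+//   public:
+//       Device GetDevice() const override;
+//       Device::DeviceType Type() const override;
+//       void Malloc(void **dev_ptr, size_t size) override;
+//       void Free(void *dev_ptr) override;
+//       void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) override;
+//   };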
+// +// TODO(zbl): add event managemnt +// +class DeviceGuardImpl { +public: + DeviceGuardImpl() {} + + virtual ~DeviceGuardImpl() = default; + + // ---------------------------------------------------------------------- + // Device management + // ---------------------------------------------------------------------- + + virtual Device GetDevice() const = 0; + + virtual void SetDevice(Device device) const; + + virtual int8_t DeviceCount() const; + + virtual Device::DeviceType Type() const = 0; + + // ---------------------------------------------------------------------- + // Stream management + // ---------------------------------------------------------------------- + + virtual Stream *GetStream(Device) const; + + // ---------------------------------------------------------------------- + // Synchronization + // ---------------------------------------------------------------------- + + virtual void SynchronizeDevice(Device) const; + + virtual void SynchronizeStream(Stream *) const; + + // ---------------------------------------------------------------------- + // BLAS handle + // ---------------------------------------------------------------------- + + virtual BlasHandle *GetBlasHandle(Device) const; + + // ---------------------------------------------------------------------- + // Memory operations + // ---------------------------------------------------------------------- + + virtual void Malloc(void **dev_ptr, size_t size) = 0; + + virtual void MallocAsync(void **dev_ptr, size_t size, Stream *stream); + + virtual void Free(void *dev_ptr) = 0; + + virtual void FreeAsync(void *dev_ptr, Stream *stream); + + virtual void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) = 0; + + virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream); + + virtual void ResetMemPoolHighWatermarks() const; + + virtual std::pair GetMemPoolPeakMB() const; +}; + +// +// ---------------------------------------------------------------------------- +// DeviceGuard: RAII front-end wrapper for DeviceGuardImpl +// ---------------------------------------------------------------------------- +// This class is the **public-facing device interface** for the framework. +// It automatically: +// +// • Saves the current device on construction +// • Switches to the target device +// • Restores the previous device on destruction +// +// All runtime operations (memory, streams, BLAS, sync) are forwarded to the +// backend-specific DeviceGuardImpl registered for that device type. 
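+//
+// Minimal usage sketch (the device index, byte count, and `host_src` buffer
+// below are illustrative, not part of this change):
+//
+//   {
+//       DeviceGuard guard(Device(Device::DeviceType::kCUDA, 1));
+//       void *buf = nullptr;
+//       guard.Malloc(&buf, 4096);                      // allocated on CUDA:1
+//       guard.Memcpy(buf, host_src, 4096, MemcpyKind::kH2D);
+//       guard.SynchronizeDevice(guard.GetDevice());
+//       guard.Free(buf);
+//   } // destructor restores the previously current device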
+// +class DeviceGuard { +public: + explicit DeviceGuard(Device device); + + ~DeviceGuard(); + + DeviceGuard(const DeviceGuard &) = delete; + DeviceGuard &operator=(const DeviceGuard &) = delete; + + // Device operations + Device GetDevice() const; + + void SetDevice(Device device) const; + + int8_t DeviceCount() const; + + Device::DeviceType Type() const; + + // Stream operations + Stream *GetStream(Device) const; + + // Synchronization + void SynchronizeDevice(Device) const; + + void SynchronizeStream(Stream *) const; + + // BLAS + BlasHandle *GetBlasHandle(Device) const; + + // Memory operations + void Malloc(void **dev_ptr, size_t size); + + void MallocAsync(void **dev_ptr, size_t size, Stream *stream); + + void Free(void *dev_ptr); + + void FreeAsync(void *dev_ptr, Stream *stream); + + void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind); + + void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream); + +private: + DeviceGuardImpl *impl_ = nullptr; + Device original_device_; +}; + +// +// ---------------------------------------------------------------------------- +// DeviceGuardImplRegistry: Global registry of backend implementations +// ---------------------------------------------------------------------------- +// This registry stores at most one DeviceGuardImpl per DeviceType. +// Backends register themselves at static initialization time via the macro +// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(). +// +// Example: +// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(DeviceType::kCUDA, CudaGuardImpl) +// +class DeviceGuardImplRegistry { +public: + static DeviceGuardImplRegistry &Instance(); + + void Register(Device::DeviceType type, std::unique_ptr impl); + + DeviceGuardImpl *Get(Device::DeviceType type) const; + +private: + DeviceGuardImplRegistry() = default; + DeviceGuardImplRegistry(const DeviceGuardImplRegistry &) = delete; + DeviceGuardImplRegistry &operator=(const DeviceGuardImplRegistry &) = delete; + + std::unordered_map> impls_; +}; + +DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type); + +} // namespace infini_train::core + +// +// ---------------------------------------------------------------------------- +// Registration macro +// ---------------------------------------------------------------------------- +// Registers a DeviceGuardImpl implementation into the global registry +// at static initialization time. 
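+//
+// Note: the registry holds the built-in CPU implementation plus at most one
+// additional non-CPU backend; registering a second non-CPU backend is
+// rejected with LOG(FATAL) (see DeviceGuardImplRegistry::Register in
+// device_guard.cc).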
+// +// Example usage: +// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(DeviceType::kCUDA, CudaGuardImpl) +// +#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \ + static const bool __infini_train_device_guard_registered##__COUNTER__ = []() { \ + infini_train::core::DeviceGuardImplRegistry::Instance().Register(device_type, std::make_unique()); \ + return true; \ + }(); diff --git a/infini_train/include/core/stream.h b/infini_train/include/core/stream.h new file mode 100644 index 00000000..190298f6 --- /dev/null +++ b/infini_train/include/core/stream.h @@ -0,0 +1,10 @@ +#pragma once + +namespace infini_train::core { + +class Stream { +public: + virtual ~Stream() = default; +}; + +} // namespace infini_train::core diff --git a/infini_train/include/device.h b/infini_train/include/device.h index 6537c3f5..d4dd76ff 100644 --- a/infini_train/include/device.h +++ b/infini_train/include/device.h @@ -1,104 +1,43 @@ #pragma once #include -#include -#include -#include - -#ifdef USE_CUDA -#include -#endif - -#include "glog/logging.h" +#include +#include #include "infini_train/include/nn/parallel/rank.h" namespace infini_train { -enum class DeviceType : int8_t { - kCPU = 0, - kCUDA = 1, - kCount = 2, -}; - -class DeviceManager; - class Device { public: - virtual ~Device() = default; + enum class DeviceType : int8_t { + kCPU = 0, + kCUDA = 1, + kInvalid = -1, + }; - DeviceType Type() const; - int8_t Index() const; + Device(); + Device &operator=(const Device &) = default; - bool IsCPU() const; - bool IsCUDA() const; - - virtual void SetDevice() const {} - virtual void Synchronize() const {} - - std::string ToString() const; - - virtual nn::parallel::Rank rank() const; - - friend std::ostream &operator<<(std::ostream &os, const Device &device); - -protected: Device(DeviceType type, int8_t index); - DeviceType type_; - int8_t index_; -}; + ~Device() = default; -class CpuDevice : public Device { -private: - CpuDevice(); - - friend class DeviceManager; -}; + DeviceType type() const; + int8_t index() const; -#ifdef USE_CUDA -class CudaDevice : public Device { -public: - ~CudaDevice() override; - - void SetDevice() const override; - void Synchronize() const override; - - cudaStream_t Stream() const; + bool IsCPU() const; + bool IsCUDA() const; - cublasHandle_t CublasHandle() const; + std::string ToString() const; - nn::parallel::Rank rank() const override; + virtual nn::parallel::Rank Rank() const; - void ResetMemPoolHighWatermarks() const; - std::pair GetMemPoolPeakMB() const; + friend std::ostream &operator<<(std::ostream &os, const Device &device); private: - CudaDevice(int8_t index); - - cudaStream_t stream_ = nullptr; - - cublasHandle_t cublas_handle_ = nullptr; - - nn::parallel::Rank rank_; - - friend class DeviceManager; + DeviceType type_ = DeviceType::kInvalid; + int8_t index_ = -1; }; -#endif - -class DeviceManager { -public: - static const DeviceManager *Instance(); - - const Device *GetDevice(DeviceType type, int8_t index = 0) const; - - const Device *GetDefaultDevice() const; - std::vector GetAllAvailableDevices(DeviceType device_type) const; - -private: - DeviceManager(); - - std::unordered_map>> devices_map_; -}; } // namespace infini_train diff --git a/infini_train/src/core/cpu/cpu_guard.cc b/infini_train/src/core/cpu/cpu_guard.cc new file mode 100644 index 00000000..6d98d30f --- /dev/null +++ b/infini_train/src/core/cpu/cpu_guard.cc @@ -0,0 +1,18 @@ +#include "infini_train/src/core/cpu/cpu_guard.h" + +#include +#include + +namespace infini_train::core::cpu { + +Device 
CpuGuardImpl::GetDevice() const { return Device(Device::DeviceType::kCPU, 0); } + +Device::DeviceType CpuGuardImpl::Type() const { return Device::DeviceType::kCPU; } + +void CpuGuardImpl::Malloc(void **dev_ptr, size_t size) { *dev_ptr = std::malloc(size); } + +void CpuGuardImpl::Free(void *dev_ptr) { std::free(dev_ptr); } + +void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { std::memcpy(dst, src, count); } + +} // namespace infini_train::core::cpu diff --git a/infini_train/src/core/cpu/cpu_guard.h b/infini_train/src/core/cpu/cpu_guard.h new file mode 100644 index 00000000..3b6ac71f --- /dev/null +++ b/infini_train/src/core/cpu/cpu_guard.h @@ -0,0 +1,22 @@ +#pragma once + +#include "infini_train/include/core/device_guard.h" + +namespace infini_train::core::cpu { + +class CpuGuardImpl : public DeviceGuardImpl { +public: + CpuGuardImpl(); + + Device GetDevice() const; + + Device::DeviceType Type() const; + + void Malloc(void **dev_ptr, size_t size); + + void Free(void *dev_ptr); + + void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind); +}; + +} // namespace infini_train::core::cpu diff --git a/infini_train/src/core/cuda/cuda_blas_handle.cc b/infini_train/src/core/cuda/cuda_blas_handle.cc new file mode 100644 index 00000000..38fe04cb --- /dev/null +++ b/infini_train/src/core/cuda/cuda_blas_handle.cc @@ -0,0 +1,15 @@ + +#include "infini_train/src/core/cuda/cuda_blas_handle.h" + +#include "infini_train/include/common/cuda/common_cuda.h" + +#include "infini_train/src/core/cuda/cuda_stream.h" + +namespace infini_train::core::cuda { + +CudaBlasHandle::CudaBlasHandle(Stream *stream) { + CUBLAS_CHECK(cublasCreate(&cublas_handle_)); + CUBLAS_CHECK(cublasSetStream(cublas_handle_, dynamic_cast(stream)->cuda_stream())); +} + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_blas_handle.h b/infini_train/src/core/cuda/cuda_blas_handle.h new file mode 100644 index 00000000..86e1a53a --- /dev/null +++ b/infini_train/src/core/cuda/cuda_blas_handle.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "infini_train/include/core/blas_handle.h" + +namespace infini_train::core { +class Stream; +} + +namespace infini_train::core::cuda { + +class CudaBlasHandle : public BlasHandle { +public: + explicit CudaBlasHandle(Stream *stream); + +private: + cublasHandle_t cublas_handle_; +}; + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_guard.cc b/infini_train/src/core/cuda/cuda_guard.cc new file mode 100644 index 00000000..ae0b34ef --- /dev/null +++ b/infini_train/src/core/cuda/cuda_guard.cc @@ -0,0 +1,131 @@ +#include "infini_train/src/core/cuda/cuda_guard.h" + +#include +#include +#include +#include + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/blas_handle.h" +#include "infini_train/include/device.h" + +#include "infini_train/src/core/cuda/cuda_blas_handle.h" +#include "infini_train/src/core/cuda/cuda_stream.h" + +namespace infini_train::core::cuda { +namespace { +constexpr int kMaxGpus = 8; + +static std::array, kMaxGpus> cuda_streams; +static std::array, kMaxGpus> cuda_blas_handles; + +static std::array device_stream_flags; +static std::array device_handle_flags; +} // namespace + +void CudaGuardImpl::InitSingleStream(Device device) { + int current_device = -1; + CUDA_CHECK(cudaGetDevice(¤t_device)); + CUDA_CHECK(cudaSetDevice(device.index())); + + cuda_streams[device.index()] = std::make_unique(); + + 
CUDA_CHECK(cudaSetDevice(current_device)); +} + +void CudaGuardImpl::InitSingleHandle(Device device) { + int current_device = -1; + CUDA_CHECK(cudaGetDevice(¤t_device)); + CUDA_CHECK(cudaSetDevice(device.index())); + + std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device.index()); + + cuda_blas_handles[device.index()] = std::make_unique(cuda_streams[device.index()].get()); + + CUDA_CHECK(cudaSetDevice(current_device)); +} + +CudaGuardImpl::CudaGuardImpl() {} + +// device +Device CudaGuardImpl::GetDevice() const { + int current_device = -1; + CUDA_CHECK(cudaGetDevice(¤t_device)); + return Device(Device::DeviceType::kCUDA, current_device); +} + +void CudaGuardImpl::SetDevice(Device device) const { CUDA_CHECK(cudaSetDevice(device.index())); } + +int8_t CudaGuardImpl::DeviceCount() const { + int device_count = 0; + CUDA_DRIVER_CHECK(cuDeviceGetCount(&device_count)); + return device_count; +} + +Device::DeviceType CudaGuardImpl::Type() const { return Device::DeviceType::kCUDA; } + +// stream +Stream *CudaGuardImpl::GetStream(Device device) const { + std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device); + return cuda_streams.at(device.index()).get(); +} + +// event + +// sync +void CudaGuardImpl::SynchronizeDevice(Device device) const { + auto original_device = GetDevice(); + SetDevice(device); + + CUDA_CHECK(cudaDeviceSynchronize()); + + SetDevice(original_device); +} + +// blas +BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const { + std::call_once(device_handle_flags.at(device.index()), InitSingleStream, device); + return cuda_blas_handles.at(device.index()).get(); +} + +// memory +void CudaGuardImpl::Malloc(void **dev_ptr, size_t size) { CUDA_CHECK(cudaMalloc(dev_ptr, size)); } + +void CudaGuardImpl::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { + CUDA_CHECK(cudaMallocAsync(dev_ptr, size, dynamic_cast(stream)->cuda_stream())); +} + +void CudaGuardImpl::Free(void *dev_ptr) { CUDA_CHECK(cudaFree(dev_ptr)); } + +void CudaGuardImpl::FreeAsync(void *dev_ptr, Stream *stream) { + CUDA_CHECK(cudaFreeAsync(dev_ptr, dynamic_cast(stream)->cuda_stream())); +} + +void CudaGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { + if (kind == MemcpyKind::kH2D) { + CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyHostToDevice)); + } else if (kind == MemcpyKind::kD2H) { + CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToHost)); + } else if (kind == MemcpyKind::kD2D) { + CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToDevice)); + } else { + LOG(FATAL) << "Invalid MemcpyKind"; + } +} + +void CudaGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { + cudaStream_t cuda_stream = dynamic_cast(stream)->cuda_stream(); + if (kind == MemcpyKind::kH2D) { + CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyHostToDevice, cuda_stream)); + } else if (kind == MemcpyKind::kD2H) { + CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToHost, cuda_stream)); + } else if (kind == MemcpyKind::kD2D) { + CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, cuda_stream)); + } else { + LOG(FATAL) << "Invalid MemcpyKind"; + } +} + +INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_guard.h b/infini_train/src/core/cuda/cuda_guard.h new file mode 100644 index 00000000..e8360025 --- /dev/null +++ 
b/infini_train/src/core/cuda/cuda_guard.h @@ -0,0 +1,54 @@ +#pragma once + +#include + +#include "infini_train/include/core/blas_handle.h" +#include "infini_train/include/core/device_guard.h" +#include "infini_train/include/core/stream.h" +#include "infini_train/include/device.h" + +namespace infini_train::core::cuda { + +class CudaGuardImpl : public DeviceGuardImpl { +public: + static void InitSingleStream(Device device); + + static void InitSingleHandle(Device device); + + CudaGuardImpl(); + + // device + Device GetDevice() const override; + + void SetDevice(Device device) const override; + + int8_t DeviceCount() const override; + + Device::DeviceType Type() const override; + + // stream + Stream *GetStream(Device device) const override; + + // event + + // sync + void SynchronizeDevice(Device device) const override; + + // blas + BlasHandle *GetBlasHandle(Device device) const override; + + // memory + void Malloc(void **dev_ptr, size_t size) override; + + void MallocAsync(void **dev_ptr, size_t size, Stream *stream) override; + + void Free(void *dev_ptr) override; + + void FreeAsync(void *dev_ptr, Stream *stream) override; + + void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) override; + + void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) override; +}; + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_stream.cc b/infini_train/src/core/cuda/cuda_stream.cc new file mode 100644 index 00000000..89319d08 --- /dev/null +++ b/infini_train/src/core/cuda/cuda_stream.cc @@ -0,0 +1,12 @@ +#include "infini_train/src/core/cuda/cuda_stream.h" + +#include + +#include "infini_train/include/common/cuda/common_cuda.h" + +namespace infini_train::core::cuda { +CudaStream::CudaStream() { CUDA_CHECK(cudaStreamCreate(&stream_)); } + +cudaStream_t CudaStream::cuda_stream() { return stream_; } + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_stream.h b/infini_train/src/core/cuda/cuda_stream.h new file mode 100644 index 00000000..40dc7235 --- /dev/null +++ b/infini_train/src/core/cuda/cuda_stream.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "infini_train/include/core/stream.h" + +namespace infini_train::core::cuda { + +class CudaStream : public Stream { +public: + CudaStream(); + + cudaStream_t cuda_stream(); + +private: + cudaStream_t stream_; +}; + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/device_guard.cc b/infini_train/src/core/device_guard.cc new file mode 100644 index 00000000..fb210d75 --- /dev/null +++ b/infini_train/src/core/device_guard.cc @@ -0,0 +1,149 @@ +#include "infini_train/include/core/device_guard.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/blas_handle.h" +#include "infini_train/include/core/stream.h" +#include "infini_train/src/core/cpu/cpu_guard.h" + +namespace infini_train::core { + +// DeviceGuardImpl +void DeviceGuardImpl::SetDevice(Device device) const { + LOG(WARNING) << std::format("SetDevice is not supported for device type {} (index {}). " + "The call is ignored.", + static_cast(device.type()), device.index()); +} + +int8_t DeviceGuardImpl::DeviceCount() const { return -1; } + +Stream *DeviceGuardImpl::GetStream(Device) const { return nullptr; } + +void DeviceGuardImpl::SynchronizeDevice(Device device) const { + LOG(WARNING) << std::format("SynchronizeDevice is not supported for this device. 
" + "The call is ignored.", + static_cast(device.type()), device.index()); +} + +void DeviceGuardImpl::SynchronizeStream(Stream *) const { + LOG(WARNING) << "SynchronizeStream is not supported for this device. " + "The call is ignored."; +} + +BlasHandle *DeviceGuardImpl::GetBlasHandle(Device device) const { + LOG(FATAL) << std::format("GetBlasHandle is not supported for device type {} (index {}). ", + static_cast(device.type()), device.index()); +} + +void DeviceGuardImpl::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { + LOG(WARNING) << "MallocAsync is not supported on this device. Falling back to blocking Malloc()"; + Malloc(dev_ptr, size); +} + +void DeviceGuardImpl::FreeAsync(void *dev_ptr, Stream *stream) { + LOG(WARNING) << "FreeAsync is not supported on this device. Falling back to blocking Free()"; + Free(dev_ptr); +} + +void DeviceGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { + LOG(WARNING) << "MemcpyAsync is not supported on this device. Falling back to blocking Memcpy()"; + Memcpy(dst, src, count, kind); +} + +void DeviceGuardImpl::ResetMemPoolHighWatermarks() const { + LOG(WARNING) << "ResetMemPoolHighWatermarks is not supported for this device. " + "The call is ignored."; +} + +std::pair DeviceGuardImpl::GetMemPoolPeakMB() const { + LOG(WARNING) << "GetMemPoolPeakMB is not supported for this device. " + "Returning {0, 0}."; + return {0, 0}; +} + +// DeviceGuard +DeviceGuard::DeviceGuard(Device device) : impl_(GetDeviceGuardImpl(device.type())) { + original_device_ = impl_->GetDevice(); + impl_->SetDevice(device); +} + +DeviceGuard::~DeviceGuard() { impl_->SetDevice(original_device_); } + +Device DeviceGuard::GetDevice() const { return impl_->GetDevice(); } + +void DeviceGuard::SetDevice(Device device) const { return impl_->SetDevice(device); } + +int8_t DeviceGuard::DeviceCount() const { return impl_->DeviceCount(); } + +Device::DeviceType DeviceGuard::Type() const { return impl_->Type(); } + +Stream *DeviceGuard::GetStream(Device device) const { return impl_->GetStream(device); } + +void DeviceGuard::SynchronizeDevice(Device device) const { return impl_->SynchronizeDevice(device); } + +void DeviceGuard::SynchronizeStream(Stream *stream) const { return impl_->SynchronizeStream(stream); } + +BlasHandle *DeviceGuard::GetBlasHandle(Device device) const { return impl_->GetBlasHandle(device); } + +void DeviceGuard::Malloc(void **dev_ptr, size_t size) { impl_->Malloc(dev_ptr, size); } + +void DeviceGuard::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { + impl_->MallocAsync(dev_ptr, size, stream); +} + +void DeviceGuard::Free(void *dev_ptr) { impl_->Free(dev_ptr); } + +void DeviceGuard::FreeAsync(void *dev_ptr, Stream *stream) { impl_->FreeAsync(dev_ptr, stream); } + +void DeviceGuard::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { + impl_->Memcpy(dst, src, count, kind); +} + +void DeviceGuard::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { + impl_->MemcpyAsync(dst, src, count, kind, stream); +} + +// DeviceGuardImplRegistry +DeviceGuardImplRegistry &DeviceGuardImplRegistry::Instance() { + static DeviceGuardImplRegistry instance; + instance.Register(Device::DeviceType::kCPU, std::make_unique()); + return instance; +} + +void DeviceGuardImplRegistry::Register(Device::DeviceType type, std::unique_ptr impl) { + if (type != impl->Type()) { + LOG(FATAL) << std::format("Register device guard impl with type {}, but as type {}", + 
static_cast(impl->Type()), static_cast(type)); + } + + if (impls_.contains(type)) { + LOG(FATAL) << std::format("DeviceGuardImpl for type {} already registrered", static_cast(type)); + } + + if (!impls_.empty()) { + for (auto &kv : impls_) { + if (kv.first != Device::DeviceType::kCPU) { + LOG(FATAL) << std::format("Only CPU and one GPU backend allowed. Already have GPU={}, new={} rejected.", + static_cast(kv.first), static_cast(type)); + } + } + } + + impls_[type] = std::move(impl); +} + +DeviceGuardImpl *DeviceGuardImplRegistry::Get(Device::DeviceType type) const { + auto it = impls_.find(type); + if (it == impls_.end()) { + LOG(FATAL) << "No DeviceGuardImpl registered for type " << static_cast(type); + } + return it->second.get(); +} + +DeviceGuardImpl *GetDeviceGuard(Device::DeviceType type) { return DeviceGuardImplRegistry::Instance().Get(type); } + +} // namespace infini_train::core diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc index c16aeee7..faab8e6c 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -1,14 +1,13 @@ #include "infini_train/include/device.h" #include -#include +#include +#include +#include #include "glog/logging.h" #include "infini_train/include/nn/parallel/global.h" -#ifdef USE_CUDA -#include "infini_train/include/common/cuda/common_cuda.h" -#endif namespace infini_train { namespace { @@ -21,22 +20,23 @@ Device::Device(DeviceType type, int8_t index) : type_(type), index_(index) { } } -DeviceType Device::Type() const { return type_; } -int8_t Device::Index() const { return index_; } +Device::DeviceType Device::type() const { return type_; } + +int8_t Device::index() const { return index_; } bool Device::IsCPU() const { return type_ == DeviceType::kCPU; } + bool Device::IsCUDA() const { return type_ == DeviceType::kCUDA; } std::string Device::ToString() const { std::ostringstream oss; - oss << "Device(" << (type_ == DeviceType::kCPU ? "CPU" : "CUDA") << ", " << static_cast(index_) << ")"; + oss << std::format("Device({}, {})", type_ == DeviceType::kCPU ? 
"CPU" : "CUDA", index_); return oss.str(); } -nn::parallel::Rank Device::rank() const { - LOG(FATAL) << "Unimplemented"; - // prevent the compiler warning about control reaching the end of non-void function - std::abort(); +nn::parallel::Rank Device::Rank() const { + return {nn::parallel::global::GetGlobalProcRank(), index_, nn::parallel::global::GetNprocPerNode(), + nn::parallel::global::GetNthreadPerProc()}; } std::ostream &operator<<(std::ostream &os, const Device &device) { @@ -44,97 +44,4 @@ std::ostream &operator<<(std::ostream &os, const Device &device) { return os; } -CpuDevice::CpuDevice() : Device(DeviceType::kCPU, 0) {} - -#ifdef USE_CUDA -CudaDevice::~CudaDevice() { - if (stream_ != nullptr) { - CUDA_CHECK(cudaStreamDestroy(stream_)); - } - - if (cublas_handle_ != nullptr) { - CUBLAS_CHECK(cublasDestroy(cublas_handle_)); - } -} - -void CudaDevice::SetDevice() const { CUDA_CHECK(cudaSetDevice(index_)); } -void CudaDevice::Synchronize() const { CUDA_CHECK(cudaDeviceSynchronize()); } - -cudaStream_t CudaDevice::Stream() const { return stream_; } - -cublasHandle_t CudaDevice::CublasHandle() const { return cublas_handle_; } - -nn::parallel::Rank CudaDevice::rank() const { return rank_; } - -CudaDevice::CudaDevice(int8_t index) - : Device(DeviceType::kCUDA, index), - rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(), - nn::parallel::global::GetNthreadPerProc()}) { - // TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode - SetDevice(); - CUDA_CHECK(cudaStreamCreate(&stream_)); - - CUBLAS_CHECK(cublasCreate(&cublas_handle_)); - CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream_)); -} - -void CudaDevice::ResetMemPoolHighWatermarks() const { - SetDevice(); - cudaMemPool_t pool; - CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, index_)); - - cuuint64_t zero = 0; - // High watermark can only be reset to zero; non-zero is illegal. 
- CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &zero)); - CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &zero)); -} - -std::pair CudaDevice::GetMemPoolPeakMB() const { - SetDevice(); - cudaMemPool_t pool; - CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, index_)); - - cuuint64_t used = 0; - CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &used)); - - cuuint64_t reserved = 0; - CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &reserved)); - - return std::make_pair(static_cast(used / kBytesPerMB), - static_cast(reserved / kBytesPerMB)); -} -#endif // USE_CUDA - -const DeviceManager *DeviceManager::Instance() { - static auto instance = std::unique_ptr(new DeviceManager()); - return instance.get(); -} - -const Device *DeviceManager::GetDevice(DeviceType type, int8_t index) const { - return devices_map_.at(type).at(index).get(); -} - -const Device *DeviceManager::GetDefaultDevice() const { return devices_map_.at(DeviceType::kCPU).at(0).get(); } - -std::vector DeviceManager::GetAllAvailableDevices(DeviceType device_type) const { - std::vector devices; - for (const auto &device : devices_map_.at(device_type)) { devices.push_back(device.get()); } - return devices; -} - -DeviceManager::DeviceManager() { - devices_map_[DeviceType::kCPU].push_back(std::unique_ptr(new CpuDevice())); -#ifdef USE_CUDA - CUDA_DRIVER_CHECK(cuInit(0)); - int device_count = 0; - CUDA_DRIVER_CHECK(cuDeviceGetCount(&device_count)); - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); - for (int idx = 0; idx < device_count; ++idx) { - devices_map_[DeviceType::kCUDA].push_back(std::unique_ptr(new CudaDevice(idx))); - } - CUDA_CHECK(cudaSetDevice(current_device)); -#endif -} - } // namespace infini_train diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index 2b1d486c..003ee6f4 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -2,6 +2,7 @@ #include #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -21,12 +22,15 @@ void AccumulateGrad(const std::shared_ptr &gradient, float rate, const s int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(tensor->GetDevice()); + const auto *device = tensor->GetDevice(); + + auto device_impl = GetDeviceGuardImpl(device->Type()); DispatchFunc( gradient->Dtype(), [=]() { - AccumulateGradKernel<<Stream()>>>( + AccumulateGradKernel<<(device_impl->GetStream(device))->cuda_stream()>>>( static_cast(gradient->DataPtr()), rate, static_cast(tensor->DataPtr()), num_elements); }, "CUDA AccumulateGrad"); From fb5c7439ff17c0e8c639767a6d7e77cff7b9b523 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Thu, 15 Jan 2026 15:05:56 +0000 Subject: [PATCH 2/7] refactor: remove DeviceManager usage and unify DeviceType references --- example/common/tokenizer.cc | 2 +- example/gpt2/main.cc | 14 +-- example/llama3/main.cc | 16 ++- example/llama3/net.cc | 3 +- example/llama3/net.h | 3 +- example/mnist/main.cc | 5 +- infini_train/include/autocast.h | 14 +-- infini_train/include/core/device_guard.h | 10 +- infini_train/include/device.h | 8 +- infini_train/include/dispatcher.h | 2 +- infini_train/include/nn/modules/module.h | 5 +- 
.../include/nn/parallel/pp/pipeline_stage.h | 2 +- infini_train/include/profiler.h | 8 +- infini_train/include/tensor.h | 20 ++-- infini_train/src/autograd/comm.cc | 2 +- infini_train/src/core/cuda/cuda_guard.cc | 27 +++++ infini_train/src/core/cuda/cuda_guard.h | 4 + infini_train/src/core/device_guard.cc | 19 ++- infini_train/src/device.cc | 8 +- .../src/kernels/cpu/accumulate_grad.cc | 2 +- infini_train/src/kernels/cpu/cast.cc | 2 +- infini_train/src/kernels/cpu/concat.cc | 2 +- infini_train/src/kernels/cpu/cross_entropy.cc | 2 +- infini_train/src/kernels/cpu/elementwise.cc | 2 +- infini_train/src/kernels/cpu/embedding.cc | 2 +- infini_train/src/kernels/cpu/fill.cc | 2 +- infini_train/src/kernels/cpu/gather.cc | 2 +- infini_train/src/kernels/cpu/layernorm.cc | 2 +- infini_train/src/kernels/cpu/linear.cc | 2 +- infini_train/src/kernels/cpu/no_op.cc | 2 +- infini_train/src/kernels/cpu/outer.cc | 2 +- infini_train/src/kernels/cpu/reduction.cc | 2 +- infini_train/src/kernels/cpu/sigmoid.cc | 2 +- infini_train/src/kernels/cpu/slice.cc | 2 +- infini_train/src/kernels/cpu/softmax.cc | 2 +- infini_train/src/kernels/cpu/split.cc | 2 +- infini_train/src/kernels/cpu/stack.cc | 2 +- infini_train/src/kernels/cpu/transform.cc | 2 +- .../src/kernels/cuda/accumulate_grad.cu | 2 +- infini_train/src/kernels/cuda/cast.cu | 2 +- infini_train/src/kernels/cuda/comm.cu | 3 +- infini_train/src/kernels/cuda/concat.cu | 2 +- .../src/kernels/cuda/cross_entropy.cu | 7 +- infini_train/src/kernels/cuda/elementwise.cu | 2 +- infini_train/src/kernels/cuda/embedding.cu | 2 +- infini_train/src/kernels/cuda/fill.cu | 2 +- infini_train/src/kernels/cuda/gather.cu | 2 +- infini_train/src/kernels/cuda/layernorm.cu | 2 +- infini_train/src/kernels/cuda/linear.cu | 2 +- infini_train/src/kernels/cuda/no_op.cu | 2 +- infini_train/src/kernels/cuda/outer.cu | 2 +- infini_train/src/kernels/cuda/reduction.cu | 2 +- infini_train/src/kernels/cuda/slice.cu | 2 +- infini_train/src/kernels/cuda/softmax.cu | 2 +- infini_train/src/kernels/cuda/split.cu | 2 +- infini_train/src/kernels/cuda/stack.cu | 2 +- infini_train/src/kernels/cuda/transform.cu | 2 +- .../cuda/vocab_parallel_cross_entropy.cu | 2 +- infini_train/src/nn/init.cc | 16 +-- infini_train/src/nn/modules/linear.cc | 2 +- infini_train/src/nn/modules/module.cc | 5 +- infini_train/src/nn/modules/normalization.cc | 2 +- infini_train/src/nn/modules/sparse.cc | 2 +- infini_train/src/nn/parallel/data_parallel.cc | 7 +- infini_train/src/nn/parallel/ddp/reducer.cc | 8 +- .../src/nn/parallel/pp/pipeline_schedule.cc | 3 +- .../src/nn/parallel/pp/pipeline_stage.cc | 5 +- infini_train/src/nn/parallel/process_group.cc | 4 +- infini_train/src/profiler.cc | 21 ++-- infini_train/src/tensor.cc | 112 +++++++++--------- 70 files changed, 239 insertions(+), 204 deletions(-) diff --git a/example/common/tokenizer.cc b/example/common/tokenizer.cc index d330753f..cfbe4df2 100644 --- a/example/common/tokenizer.cc +++ b/example/common/tokenizer.cc @@ -121,7 +121,7 @@ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_siz uint64_t kRngState = kRngState; LOG(INFO) << "start generate text:"; - const auto *cpu_device = DeviceManager::Instance()->GetDefaultDevice(); + const auto *cpu_device = Device(); for (int t = prompt_len; t < text_length; ++t) { x = std::make_shared(x->To(device)); // CPU->calc device // TODO(jym): use no_grad forward later diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 3219c1f5..a0fac4bf 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc 
@@ -113,7 +113,7 @@ void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; // select the device - const Device *device; + Device device; int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); @@ -138,7 +138,7 @@ void Train(const nn::parallel::Rank &rank) { const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { - device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); + device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), @@ -162,8 +162,7 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::pp_rank = pp_rank; } } else { - device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() - : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); + device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0); } // calculate gradient accumulation from the desired total batch size and the current run configuration @@ -212,7 +211,7 @@ void Train(const nn::parallel::Rank &rank) { {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared( - model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(), + model, pp_world_size, num_micro_batches, shapes, pp_rank, device, std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { auto ddp_config @@ -347,7 +346,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward"; - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + auto loss_cpu = loss->To(Device()); lossf += static_cast(loss_cpu.DataPtr())[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward"; loss->Backward(); @@ -369,8 +368,7 @@ void Train(const nn::parallel::Rank &rank) { if (ddp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); - lossf = static_cast( - lossf_tensor->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } const auto iter_end = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index ff9b6660..d0bdcbd1 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -95,7 +95,7 @@ void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; // select the device - const Device *device; + Device device; int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); @@ -119,7 +119,7 @@ void Train(const nn::parallel::Rank &rank) { const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { - device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); + device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), @@ -143,8 +143,7 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::pp_rank = pp_rank; } } else { - device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() - : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); + device = FLAGS_device == kDeviceCPU ? 
Device() : Device(Device::DeviceType::kCUDA, 0); } // calculate gradient accumulation from the desired total batch size and the current run configuration @@ -191,7 +190,7 @@ void Train(const nn::parallel::Rank &rank) { {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared( - model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(), + model, pp_world_size, num_micro_batches, shapes, pp_rank, device, std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { auto ddp_config @@ -300,7 +299,7 @@ void Train(const nn::parallel::Rank &rank) { for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + infini_train::AutocastGuard autocast_guard(device.type(), dtype); // (bs, seq_len), (bs, seq_len) auto [x, y] = *train_iter; @@ -323,7 +322,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward"; - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + auto loss_cpu = loss->To(Device()); lossf += static_cast(loss_cpu.DataPtr())[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward"; loss->Backward(); @@ -345,8 +344,7 @@ void Train(const nn::parallel::Rank &rank) { if (ddp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); - lossf = static_cast( - lossf_tensor->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } const auto iter_end = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 50f200f8..b92e35a8 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -100,8 +100,7 @@ ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, - const infini_train::Device *device - = DeviceManager::Instance()->GetDefaultDevice()) { + const infini_train::Device *device = Device()) { DataType dtype = DataType::kFLOAT32; CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); diff --git a/example/llama3/net.h b/example/llama3/net.h index 034aa9e8..845d56e6 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -55,8 +55,7 @@ class RMSNorm : public infini_train::nn::CloneableModule { static constexpr char kType[] = "RMSNorm"; static constexpr char kParamWeightName[] = "weight"; - explicit RMSNorm(int64_t dim, float eps = 1e-6f, - const infini_train::Device *device = infini_train::DeviceManager::Instance()->GetDefaultDevice()); + explicit RMSNorm(int64_t dim, float eps = 1e-6f, const infini_train::Device *device = infini_train::Device()); std::vector> Forward(const std::vector> &x) override; diff --git a/example/mnist/main.cc b/example/mnist/main.cc index 097529bf..e62257d7 100644 --- a/example/mnist/main.cc +++ b/example/mnist/main.cc @@ -48,9 +48,8 @@ int main(int argc, char *argv[]) { DataLoader test_dataloader(test_dataset, FLAGS_bs); auto network = MNIST(); - const Device *device = FLAGS_device == kDeviceCPU ? 
DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0) - : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); - const Device *cpu_device = DeviceManager::Instance()->GetDefaultDevice(); + Device device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0); + Device cpu_device = Device(); network.To(device); auto loss_fn = nn::CrossEntropyLoss(); diff --git a/infini_train/include/autocast.h b/infini_train/include/autocast.h index e5bcf6af..3fb195c0 100644 --- a/infini_train/include/autocast.h +++ b/infini_train/include/autocast.h @@ -91,18 +91,18 @@ inline const std::unordered_map kOpCastPolicyMap = }; // Default autocast data types for each device type -inline constexpr std::array(DeviceType::kCount)> kDeviceDefaultDtype = { +inline constexpr std::array(Device::DeviceType::kCount)> kDeviceDefaultDtype = { DataType::kBFLOAT16, // CPU DataType::kFLOAT16, // CUDA. }; // Thread-local context to track autocast state struct AutocastContext { - bool enabled = false; // Whether autocast is active in the current thread - DeviceType device_type = DeviceType::kCPU; // Target device type (CPU/GPU) - DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting + bool enabled = false; // Whether autocast is active in the current thread + Device::DeviceType device_type = Device::DeviceType::kCPU; // Target device type (CPU/GPU) + DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting - template void Autocast(std::pair key, ArgsT &...args) { + template void Autocast(std::pair key, ArgsT &...args) { if (!enabled) { return; } @@ -172,14 +172,14 @@ inline thread_local AutocastContext tls_autocast_context; // RAII guard to enable/disable autocast in a scope class AutocastGuard { public: - AutocastGuard(DeviceType device_type, DataType autocast_dtype) { + AutocastGuard(Device::DeviceType device_type, DataType autocast_dtype) { saved_context_ = tls_autocast_context; tls_autocast_context.enabled = true; tls_autocast_context.device_type = device_type; tls_autocast_context.autocast_dtype = autocast_dtype; } - AutocastGuard(DeviceType device_type) + AutocastGuard(Device::DeviceType device_type) : AutocastGuard(device_type, kDeviceDefaultDtype[static_cast(device_type)]) {} // Disable autocast (restore previous state) diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h index af37e42a..8f4d62b2 100644 --- a/infini_train/include/core/device_guard.h +++ b/infini_train/include/core/device_guard.h @@ -95,9 +95,9 @@ class DeviceGuardImpl { virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream); - virtual void ResetMemPoolHighWatermarks() const; + virtual void ResetMemPoolHighWatermarks(Device device) const; - virtual std::pair GetMemPoolPeakMB() const; + virtual std::pair GetMemPoolPeakMB(Device device) const; }; // @@ -170,7 +170,7 @@ class DeviceGuard { // INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(). // // Example: -// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(DeviceType::kCUDA, CudaGuardImpl) +// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) // class DeviceGuardImplRegistry { public: @@ -188,8 +188,6 @@ class DeviceGuardImplRegistry { std::unordered_map> impls_; }; -DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type); - } // namespace infini_train::core // @@ -200,7 +198,7 @@ DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type); // at static initialization time. 
// // Example usage: -// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(DeviceType::kCUDA, CudaGuardImpl) +// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) // #define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \ static const bool __infini_train_device_guard_registered##__COUNTER__ = []() { \ diff --git a/infini_train/include/device.h b/infini_train/include/device.h index d4dd76ff..28db395f 100644 --- a/infini_train/include/device.h +++ b/infini_train/include/device.h @@ -13,14 +13,16 @@ class Device { enum class DeviceType : int8_t { kCPU = 0, kCUDA = 1, + kCount = 2, kInvalid = -1, }; Device(); - Device &operator=(const Device &) = default; Device(DeviceType type, int8_t index); + Device &operator=(const Device &) = default; + ~Device() = default; DeviceType type() const; @@ -35,6 +37,10 @@ class Device { friend std::ostream &operator<<(std::ostream &os, const Device &device); + friend bool operator==(const Device &a, const Device &b); + + friend bool operator!=(const Device &a, const Device &b); + private: DeviceType type_ = DeviceType::kInvalid; int8_t index_ = -1; diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 7b87d59a..9ebe9e2a 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -413,7 +413,7 @@ class KernelFunction { class Dispatcher { public: - using KeyT = std::pair; + using KeyT = std::pair; static Dispatcher &Instance() { static Dispatcher instance; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 398166b5..98dc7e5b 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -7,6 +7,7 @@ #include #include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" namespace infini_train { class Tensor; @@ -72,7 +73,7 @@ class Module : public std::enable_shared_from_this { return 0.0f; }; - virtual void To(const Device *device); + virtual void To(Device device); virtual void To(DataType dtype); @@ -91,7 +92,7 @@ class Module : public std::enable_shared_from_this { std::shared_ptr RegisterBackwardPostHook(ModulePostHook hook); protected: - const Device *device_ = nullptr; + Device device_ = Device(); const std::string type_ = kUndefinedType; std::unordered_map> modules_; std::unordered_map> parameters_; diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index 7a188cd4..b59369dc 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -16,7 +16,7 @@ namespace infini_train::nn::parallel { class PipelineStage { public: - PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, int device_id, + PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, Device device, std::vector> &&chunks); std::vector> ForwardOneChunk(const std::vector> &inputs, diff --git a/infini_train/include/profiler.h b/infini_train/include/profiler.h index bb54bfcc..b0aa1fbe 100644 --- a/infini_train/include/profiler.h +++ b/infini_train/include/profiler.h @@ -17,12 +17,12 @@ inline thread_local int g_profiling_depth = 0; struct ProfileContext { std::string name; - DeviceType device; + Device::DeviceType device; }; inline thread_local ProfileContext g_profile_context; -inline void SetProfileContext(const std::string &name, DeviceType device) { +inline void SetProfileContext(const std::string 
&name, Device::DeviceType device) { if (g_profiling_depth == 0) { g_profile_context.name = name; g_profile_context.device = device; @@ -63,8 +63,8 @@ class Profiler { static Profiler &Instance(); - void StartRecord(const std::string &name, DeviceType device); - void EndRecord(const std::string &name, DeviceType device); + void StartRecord(const std::string &name, Device::DeviceType device); + void EndRecord(const std::string &name, Device::DeviceType device); void Report(std::ostream &os = std::cout, SortBy sort_by = SortBy::NotSorted) const; void Report(const std::string &file_path, SortBy sort_by = SortBy::NotSorted) const; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 6ff3fa64..b3a9b042 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -38,17 +38,17 @@ struct PrintOptions { class TensorBuffer { public: - TensorBuffer(const Device *device, size_t size); + TensorBuffer(Device device, size_t size); ~TensorBuffer(); void *DataPtr(); const void *DataPtr() const; - const Device *GetDevice() const; + Device GetDevice() const; size_t Size() const; private: - const Device *device_ = nullptr; + Device device_ = Device(); size_t size_ = 0; void *data_ = nullptr; }; @@ -57,19 +57,17 @@ class Tensor : public std::enable_shared_from_this { public: Tensor() = default; - Tensor(const std::vector &dims, DataType dtype, const Device *device); - Tensor(const std::vector &dims, DataType dtype) - : Tensor(dims, dtype, DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0)) {} + Tensor(const std::vector &dims, DataType dtype, Device device); + Tensor(const std::vector &dims, DataType dtype) : Tensor(dims, dtype, Device()) {} Tensor(const Tensor &tensor, size_t offset, const std::vector &dims); void SetData(const Tensor &tensor, size_t offset, bool preserve_data = false); - Tensor(const float *data, const std::vector &dims, DataType dtype, const Device *device); - Tensor(const float *data, const std::vector &dims, DataType dtype) - : Tensor(data, dims, dtype, DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0)) {} + Tensor(const float *data, const std::vector &dims, DataType dtype, Device device); + Tensor(const float *data, const std::vector &dims, DataType dtype) : Tensor(data, dims, dtype, Device()) {} - const Device *GetDevice() const; + Device GetDevice() const; void *DataPtr(); const void *DataPtr() const; @@ -86,7 +84,7 @@ class Tensor : public std::enable_shared_from_this { Eigen::Map> EigenVector(); // TODO(dcj): return shared_ptr instead of Tensor later - Tensor To(const Device *device); + Tensor To(Device device); Tensor To(DataType dtype); void CopyFrom(const Tensor &src); diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index 1bcad973..48e70f09 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -38,7 +38,7 @@ Gather::Gather(const Device *target_device, int64_t dim, const infini_train::nn: std::vector> Gather::Forward(const std::vector> &input_tensors) { for (const auto &tensor : input_tensors) { - CHECK_NE(static_cast(tensor->GetDevice()->Type()), static_cast(DeviceType::kCPU)) + CHECK_NE(static_cast(tensor->GetDevice()->Type()), static_cast(Device::DeviceType::kCPU)) << "Gather function not implemented for CPU tensors"; } if (dim_ == 0 && input_tensors[0]->Dims().size() == 0) { diff --git a/infini_train/src/core/cuda/cuda_guard.cc b/infini_train/src/core/cuda/cuda_guard.cc index ae0b34ef..4ff5ecf0 100644 --- 
a/infini_train/src/core/cuda/cuda_guard.cc +++ b/infini_train/src/core/cuda/cuda_guard.cc @@ -15,6 +15,7 @@ namespace infini_train::core::cuda { namespace { constexpr int kMaxGpus = 8; +constexpr size_t kBytesPerMB = 1024ULL * 1024ULL; static std::array, kMaxGpus> cuda_streams; static std::array, kMaxGpus> cuda_blas_handles; @@ -126,6 +127,32 @@ void CudaGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, Memcpy } } +void CudaGuardImpl::ResetMemPoolHighWatermarks(Device device) const { + SetDevice(device); + cudaMemPool_t pool; + CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, device.index())); + + cuuint64_t zero = 0; + // High watermark can only be reset to zero; non-zero is illegal. + CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &zero)); + CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &zero)); +} + +std::pair CudaGuardImpl::GetMemPoolPeakMB(Device device) const { + SetDevice(device); + cudaMemPool_t pool; + CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, device.index())); + + cuuint64_t used = 0; + CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &used)); + + cuuint64_t reserved = 0; + CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &reserved)); + + return std::make_pair(static_cast(used / kBytesPerMB), + static_cast(reserved / kBytesPerMB)); +} + INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) } // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_guard.h b/infini_train/src/core/cuda/cuda_guard.h index e8360025..400bb0da 100644 --- a/infini_train/src/core/cuda/cuda_guard.h +++ b/infini_train/src/core/cuda/cuda_guard.h @@ -49,6 +49,10 @@ class CudaGuardImpl : public DeviceGuardImpl { void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) override; void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) override; + + void ResetMemPoolHighWatermarks(Device device) const override; + + std::pair GetMemPoolPeakMB(Device device) const override; }; } // namespace infini_train::core::cuda diff --git a/infini_train/src/core/device_guard.cc b/infini_train/src/core/device_guard.cc index fb210d75..048d97e0 100644 --- a/infini_train/src/core/device_guard.cc +++ b/infini_train/src/core/device_guard.cc @@ -11,6 +11,11 @@ #include "infini_train/src/core/cpu/cpu_guard.h" namespace infini_train::core { +namespace { +inline DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type) { + return DeviceGuardImplRegistry::Instance().Get(type); +} +} // namespace // DeviceGuardImpl void DeviceGuardImpl::SetDevice(Device device) const { @@ -54,14 +59,16 @@ void DeviceGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, Memc Memcpy(dst, src, count, kind); } -void DeviceGuardImpl::ResetMemPoolHighWatermarks() const { - LOG(WARNING) << "ResetMemPoolHighWatermarks is not supported for this device. " - "The call is ignored."; +void DeviceGuardImpl::ResetMemPoolHighWatermarks(Device device) const { + LOG(WARNING) << std::format("ResetMemPoolHighWatermarks is not supported for device type {} (index {}). " + "The call is ignored.", + static_cast(device.type()), device.index()); } -std::pair DeviceGuardImpl::GetMemPoolPeakMB() const { - LOG(WARNING) << "GetMemPoolPeakMB is not supported for this device. " - "Returning {0, 0}."; +std::pair DeviceGuardImpl::GetMemPoolPeakMB(Device device) const { + LOG(WARNING) << std::format("GetMemPoolPeakMB is not supported for device type {} (index {}). 
" + "Returning {{0, 0}}.", + static_cast(device.type()), device.index()); return {0, 0}; } diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc index faab8e6c..1bb3aaad 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -10,9 +10,7 @@ #include "infini_train/include/nn/parallel/global.h" namespace infini_train { -namespace { -constexpr size_t kBytesPerMB = 1024ULL * 1024ULL; -} // namespace +Device::Device() : type_(DeviceType::kCPU), index_(0) {} Device::Device(DeviceType type, int8_t index) : type_(type), index_(index) { if (type_ == DeviceType::kCPU && index_ != 0) { @@ -44,4 +42,8 @@ std::ostream &operator<<(std::ostream &os, const Device &device) { return os; } +bool operator==(const Device &a, const Device &b) { return a.type_ == b.type_ && a.index_ == b.index_; } + +bool operator!=(const Device &a, const Device &b) { return !(a == b); } + } // namespace infini_train diff --git a/infini_train/src/kernels/cpu/accumulate_grad.cc b/infini_train/src/kernels/cpu/accumulate_grad.cc index 171d722c..cfe85b9c 100644 --- a/infini_train/src/kernels/cpu/accumulate_grad.cc +++ b/infini_train/src/kernels/cpu/accumulate_grad.cc @@ -37,7 +37,7 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p } // namespace infini_train::kernels::cpu #define REGISTER_CPU_ACCUMULATE_GRAD_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_ACCUMULATE_GRAD_KERNEL(AccumulateGrad) REGISTER_CPU_ACCUMULATE_GRAD_KERNEL(AdamAccumulateGrad) diff --git a/infini_train/src/kernels/cpu/cast.cc b/infini_train/src/kernels/cpu/cast.cc index 8481eb15..35f31214 100644 --- a/infini_train/src/kernels/cpu/cast.cc +++ b/infini_train/src/kernels/cpu/cast.cc @@ -24,7 +24,7 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { } // namespace infini_train::kernels::cpu #define REGISTER_CPU_CAST_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_CAST_KERNEL(Cast) diff --git a/infini_train/src/kernels/cpu/concat.cc b/infini_train/src/kernels/cpu/concat.cc index d294eb85..b421063f 100644 --- a/infini_train/src/kernels/cpu/concat.cc +++ b/infini_train/src/kernels/cpu/concat.cc @@ -128,7 +128,7 @@ std::vector> ConcatBackward(const std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu } // namespace infini_train::kernels::cpu #define REGISTER_CPU_CROSS_ENTROPY_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_CROSS_ENTROPY_KERNEL(CrossEntropyForward) REGISTER_CPU_CROSS_ENTROPY_KERNEL(CrossEntropyBackward) diff --git a/infini_train/src/kernels/cpu/elementwise.cc b/infini_train/src/kernels/cpu/elementwise.cc index 608172b6..8d66acd2 100644 --- a/infini_train/src/kernels/cpu/elementwise.cc +++ b/infini_train/src/kernels/cpu/elementwise.cc @@ -313,7 +313,7 @@ std::pair, std::shared_ptr> DivBackward(const st } // namespace infini_train::kernels::cpu #define REGISTER_CPU_ELEMENTWISE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, 
infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_ELEMENTWISE_KERNEL(NegForward) REGISTER_CPU_ELEMENTWISE_KERNEL(NegBackward) diff --git a/infini_train/src/kernels/cpu/embedding.cc b/infini_train/src/kernels/cpu/embedding.cc index 5debac9f..190c77c5 100644 --- a/infini_train/src/kernels/cpu/embedding.cc +++ b/infini_train/src/kernels/cpu/embedding.cc @@ -56,7 +56,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, } // namespace infini_train::kernels::cpu #define REGISTER_CPU_EMBEDDING_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_EMBEDDING_KERNEL(EmbeddingForward) REGISTER_CPU_EMBEDDING_KERNEL(EmbeddingBackward) diff --git a/infini_train/src/kernels/cpu/fill.cc b/infini_train/src/kernels/cpu/fill.cc index 2e8fdbc7..175a15a2 100644 --- a/infini_train/src/kernels/cpu/fill.cc +++ b/infini_train/src/kernels/cpu/fill.cc @@ -12,7 +12,7 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { } // namespace infini_train::kernels::cpu #define REGISTER_CPU_FILL_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_FILL_KERNEL(Fill) diff --git a/infini_train/src/kernels/cpu/gather.cc b/infini_train/src/kernels/cpu/gather.cc index c612efaa..9717b795 100644 --- a/infini_train/src/kernels/cpu/gather.cc +++ b/infini_train/src/kernels/cpu/gather.cc @@ -197,7 +197,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ } // namespace infini_train::kernels::cpu #define REGISTER_CPU_GATHER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_GATHER_KERNEL(IndexGatherForward) REGISTER_CPU_GATHER_KERNEL(IndexGatherBackward) diff --git a/infini_train/src/kernels/cpu/layernorm.cc b/infini_train/src/kernels/cpu/layernorm.cc index d717f348..c587f2c5 100644 --- a/infini_train/src/kernels/cpu/layernorm.cc +++ b/infini_train/src/kernels/cpu/layernorm.cc @@ -139,7 +139,7 @@ LayerNormBackward(const std::shared_ptr &input, const std::shared_ptr &input, const std::shared_ptr NoOpBackward(const std::vector &dims, const std } // namespace infini_train::kernels::cpu #define REGISTER_CPU_NO_OP_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_NO_OP_KERNEL(NoOpForward) REGISTER_CPU_NO_OP_KERNEL(NoOpBackward) diff --git a/infini_train/src/kernels/cpu/outer.cc b/infini_train/src/kernels/cpu/outer.cc index 2991dfd3..b61a3ed0 100644 --- a/infini_train/src/kernels/cpu/outer.cc +++ b/infini_train/src/kernels/cpu/outer.cc @@ -59,7 +59,7 @@ std::tuple, std::shared_ptr> OuterBackward(const } // namespace infini_train::kernels::cpu #define REGISTER_CPU_OUTER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + 
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_OUTER_KERNEL(OuterForward) REGISTER_CPU_OUTER_KERNEL(OuterBackward) diff --git a/infini_train/src/kernels/cpu/reduction.cc b/infini_train/src/kernels/cpu/reduction.cc index 87ed5384..0aa936ba 100644 --- a/infini_train/src/kernels/cpu/reduction.cc +++ b/infini_train/src/kernels/cpu/reduction.cc @@ -169,7 +169,7 @@ std::shared_ptr MinBackward(const std::shared_ptr &grad_output, } // namespace infini_train::kernels::cpu #define REGISTER_CPU_REDUCTION_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_REDUCTION_KERNEL(MeanForward) REGISTER_CPU_REDUCTION_KERNEL(MeanBackward) diff --git a/infini_train/src/kernels/cpu/sigmoid.cc b/infini_train/src/kernels/cpu/sigmoid.cc index d4bc05da..8163a096 100644 --- a/infini_train/src/kernels/cpu/sigmoid.cc +++ b/infini_train/src/kernels/cpu/sigmoid.cc @@ -35,7 +35,7 @@ std::shared_ptr SigmoidBackward(const std::shared_ptr &output, } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SIGMOID_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SIGMOID_KERNEL(SigmoidForward) REGISTER_CPU_SIGMOID_KERNEL(SigmoidBackward) diff --git a/infini_train/src/kernels/cpu/slice.cc b/infini_train/src/kernels/cpu/slice.cc index 943b1c1b..bef925a7 100644 --- a/infini_train/src/kernels/cpu/slice.cc +++ b/infini_train/src/kernels/cpu/slice.cc @@ -130,7 +130,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SLICE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SLICE_KERNEL(SliceForward) REGISTER_CPU_SLICE_KERNEL(SliceBackward) diff --git a/infini_train/src/kernels/cpu/softmax.cc b/infini_train/src/kernels/cpu/softmax.cc index 454bdc2d..f711fbdc 100644 --- a/infini_train/src/kernels/cpu/softmax.cc +++ b/infini_train/src/kernels/cpu/softmax.cc @@ -81,7 +81,7 @@ std::shared_ptr SoftmaxBackward(const std::shared_ptr &grad_outp } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SOFTMAX_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SOFTMAX_KERNEL(SoftmaxForward) REGISTER_CPU_SOFTMAX_KERNEL(SoftmaxBackward) diff --git a/infini_train/src/kernels/cpu/split.cc b/infini_train/src/kernels/cpu/split.cc index e9a90ea9..209857f0 100644 --- a/infini_train/src/kernels/cpu/split.cc +++ b/infini_train/src/kernels/cpu/split.cc @@ -74,7 +74,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SPLIT_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, 
infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SPLIT_KERNEL(SplitForward) REGISTER_CPU_SPLIT_KERNEL(SplitBackward) diff --git a/infini_train/src/kernels/cpu/stack.cc b/infini_train/src/kernels/cpu/stack.cc index d1f71ed2..0ada6475 100644 --- a/infini_train/src/kernels/cpu/stack.cc +++ b/infini_train/src/kernels/cpu/stack.cc @@ -81,7 +81,7 @@ std::vector> StackBackward(const std::vector &i } // namespace infini_train::kernels::cpu #define REGISTER_CPU_STACK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_STACK_KERNEL(StackForward) REGISTER_CPU_STACK_KERNEL(StackBackward) diff --git a/infini_train/src/kernels/cpu/transform.cc b/infini_train/src/kernels/cpu/transform.cc index 1c1697b0..00387917 100644 --- a/infini_train/src/kernels/cpu/transform.cc +++ b/infini_train/src/kernels/cpu/transform.cc @@ -219,7 +219,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & } // namespace infini_train::kernels::cpu #define REGISTER_CPU_TRANSFORM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_TRANSFORM_KERNEL(TrilForward) REGISTER_CPU_TRANSFORM_KERNEL(TrilBackward) diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index 003ee6f4..c922cb35 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -80,7 +80,7 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(AccumulateGrad) REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(AdamAccumulateGrad) diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 6b53e8c8..0feb6dae 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -43,7 +43,7 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_CAST_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_CAST_KERNEL(Cast) diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index c84cc068..6fc7adeb 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -71,7 +71,8 @@ std::shared_ptr Gather(const std::vector> &tenso } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_COMM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, Comm##kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Comm##kernel_name, \ + infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_COMM_KERNEL(Broadcast) 
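
Note on the registration hunks above: they are mechanical, changing only the dispatch key from the old top-level DeviceType enum to the nested Device::DeviceType. As a reading aid, here is a minimal call-site sketch in the style of the Dispatcher usage that appears later in this patch (tensor.cc). The dispatcher header path and the Call<void> template argument are assumptions, and the raw float pointer is only valid for a kFLOAT32 tensor.

#include "infini_train/include/device.h"
#include "infini_train/include/dispatcher.h" // assumed header for Dispatcher
#include "infini_train/include/tensor.h"

using namespace infini_train;

// Look up the "Fill" kernel registered above via the new {Device::DeviceType, name} key.
void FillWithOnes(const std::shared_ptr<Tensor> &tensor) {
    float value = 1.0f; // only correct for DataType::kFLOAT32 tensors
    auto kernel = Dispatcher::Instance().GetKernel({tensor->GetDevice().type(), "Fill"});
    kernel.Call<void>(tensor, static_cast<void *>(&value)); // Call<> template argument assumed
}
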
REGISTER_CUDA_COMM_KERNEL(Scatter) diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index f3d1730c..6beba608 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -232,7 +232,7 @@ std::vector> ConcatBackward(const std::shared_ptr CrossEntropyForward(const std::shared_ptr &input <<Stream()>>>(input_ptr, target_ptr, batched_loss_ptr, bs, num_classes); - auto loss_cpu = batched_output->To(DeviceManager::Instance()->GetDefaultDevice()); - auto loss = std::make_shared(std::vector{}, input->Dtype(), - DeviceManager::Instance()->GetDefaultDevice()); + auto loss_cpu = batched_output->To(Device()); + auto loss = std::make_shared(std::vector{}, input->Dtype(), Device()); auto loss_cpu_typed_ptr = static_cast(loss_cpu.DataPtr()); static_cast(loss->DataPtr())[0] = std::accumulate(loss_cpu_typed_ptr, loss_cpu_typed_ptr + bs, 0.0f, @@ -207,7 +206,7 @@ std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_CROSS_ENTROPY_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_CROSS_ENTROPY_KERNEL(CrossEntropyForward) REGISTER_CUDA_CROSS_ENTROPY_KERNEL(CrossEntropyBackward) diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 913d848b..8bbc72fc 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -1081,7 +1081,7 @@ std::shared_ptr SigmoidBackward(const std::shared_ptr &output, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_ELEMENTWISE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_ELEMENTWISE_KERNEL(NegForward) REGISTER_CUDA_ELEMENTWISE_KERNEL(NegBackward) diff --git a/infini_train/src/kernels/cuda/embedding.cu b/infini_train/src/kernels/cuda/embedding.cu index 6ae904f5..f43239b2 100644 --- a/infini_train/src/kernels/cuda/embedding.cu +++ b/infini_train/src/kernels/cuda/embedding.cu @@ -105,7 +105,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_EMBEDDING_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_EMBEDDING_KERNEL(EmbeddingForward) REGISTER_CUDA_EMBEDDING_KERNEL(EmbeddingBackward) diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index 2a601032..4a5d2f45 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -32,7 +32,7 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_FILL_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_FILL_KERNEL(Fill) diff --git 
a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index cc90d4a5..47d63478 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -216,7 +216,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_GATHER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_GATHER_KERNEL(IndexGatherForward) REGISTER_CUDA_GATHER_KERNEL(IndexGatherBackward) diff --git a/infini_train/src/kernels/cuda/layernorm.cu b/infini_train/src/kernels/cuda/layernorm.cu index ae825441..70d9932f 100644 --- a/infini_train/src/kernels/cuda/layernorm.cu +++ b/infini_train/src/kernels/cuda/layernorm.cu @@ -189,7 +189,7 @@ LayerNormBackward(const std::shared_ptr &input, const std::shared_ptr &input, const std::shared_ptr NoOpBackward(const std::vector &dims, const std } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_NO_OP_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_NO_OP_KERNEL(NoOpForward) REGISTER_CUDA_NO_OP_KERNEL(NoOpBackward) diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index a0bcfe19..e9716072 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -152,7 +152,7 @@ std::tuple, std::shared_ptr> OuterBackward(const } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_OUTER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_OUTER_KERNEL(OuterForward) REGISTER_CUDA_OUTER_KERNEL(OuterBackward) diff --git a/infini_train/src/kernels/cuda/reduction.cu b/infini_train/src/kernels/cuda/reduction.cu index 9c7ff9d7..5d8f2c15 100644 --- a/infini_train/src/kernels/cuda/reduction.cu +++ b/infini_train/src/kernels/cuda/reduction.cu @@ -218,7 +218,7 @@ std::shared_ptr MinBackward(const std::shared_ptr &grad_output, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_REDUCTION_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_REDUCTION_KERNEL(MeanForward) REGISTER_CUDA_REDUCTION_KERNEL(SumForward) diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 38d4aab6..032ebc47 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -194,7 +194,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_SLICE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_SLICE_KERNEL(SliceForward) 
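
Looking back at the cuda_guard.cc hunk earlier in this patch, the new high-watermark helpers are reachable through the DeviceGuardImpl registered for kCUDA. A hedged sketch of a caller follows; it assumes DeviceGuardImplRegistry and its Get() accessor (shown in device_guard.cc above) are visible to the calling code, and uses glog for output as the rest of the project does.

#include "glog/logging.h"

#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/device.h"

using namespace infini_train;

// Report and reset the CUDA memory-pool peak for one GPU (values are in MB).
void ReportCudaPeak(int8_t gpu_index) {
    Device dev(Device::DeviceType::kCUDA, gpu_index);
    auto *impl = core::DeviceGuardImplRegistry::Instance().Get(dev.type()); // assumed publicly reachable
    auto [used_mb, reserved_mb] = impl->GetMemPoolPeakMB(dev);
    LOG(INFO) << "cuda:" << static_cast<int>(gpu_index) << " peak used " << used_mb
              << " MB, peak reserved " << reserved_mb << " MB";
    impl->ResetMemPoolHighWatermarks(dev); // high watermarks can only be reset to zero
}
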
REGISTER_CUDA_SLICE_KERNEL(SliceBackward) diff --git a/infini_train/src/kernels/cuda/softmax.cu b/infini_train/src/kernels/cuda/softmax.cu index 98d47fae..0184dc7a 100644 --- a/infini_train/src/kernels/cuda/softmax.cu +++ b/infini_train/src/kernels/cuda/softmax.cu @@ -208,7 +208,7 @@ std::shared_ptr SoftmaxBackward(const std::shared_ptr &grad_outp } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_SOFTMAX_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_SOFTMAX_KERNEL(SoftmaxForward) REGISTER_CUDA_SOFTMAX_KERNEL(SoftmaxBackward) diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index ab22bf95..5b2c4838 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -165,7 +165,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_SPLIT_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_SPLIT_KERNEL(SplitForward) REGISTER_CUDA_SPLIT_KERNEL(SplitBackward) diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 5fe4899c..cef9a05f 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -145,7 +145,7 @@ std::vector> StackBackward(const std::vector &i } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_STACK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_STACK_KERNEL(StackForward) REGISTER_CUDA_STACK_KERNEL(StackBackward) diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 7f1f818d..9a7cee41 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -545,7 +545,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_TRANSFORM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_TRANSFORM_KERNEL(TrilForward) REGISTER_CUDA_TRANSFORM_KERNEL(TrilBackward) diff --git a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu index 8b5d4450..75f0206c 100644 --- a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu +++ b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu @@ -112,7 +112,7 @@ VocabParallelCrossEntropyBackward(const std::shared_ptr &grad_output, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_VOCAB_PARALLEL_CROSS_ENTROPY_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, 
infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_VOCAB_PARALLEL_CROSS_ENTROPY_KERNEL(VocabParallelCrossEntropyBackward) diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index e00e2f8a..35fe3830 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -49,12 +49,12 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean device->SetDevice(); switch (device->Type()) { - case DeviceType::kCPU: { + case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; } #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { // TODO(dcj): maybe use async API later? cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream()); @@ -155,12 +155,12 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, device->SetDevice(); switch (device->Type()) { - case DeviceType::kCPU: { + case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; } #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { // TODO(dcj): maybe use async API later? cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream()); @@ -185,12 +185,12 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { device->SetDevice(); switch (device->Type()) { - case DeviceType::kCPU: { + case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; } #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { // TODO(dcj): maybe use async API later? cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream()); @@ -215,12 +215,12 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { device->SetDevice(); switch (device->Type()) { - case DeviceType::kCPU: { + case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; } #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { // TODO(dcj): maybe use async API later? cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream()); diff --git a/infini_train/src/nn/modules/linear.cc b/infini_train/src/nn/modules/linear.cc index 67c0d733..b6822cc1 100644 --- a/infini_train/src/nn/modules/linear.cc +++ b/infini_train/src/nn/modules/linear.cc @@ -12,7 +12,7 @@ namespace infini_train::nn { Linear::Linear(int64_t in_features, int64_t out_features, bool bias, const Device *device) : CloneableModule(kType), bias_(bias) { - device_ = device ? device : DeviceManager::Instance()->GetDefaultDevice(); + device_ = device ? 
device : Device(); parameters_[kParamWeightName] = std::make_shared(std::vector{out_features, in_features}, DataType::kFLOAT32, device_) diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 0ac1165c..1b764ed4 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -23,7 +23,7 @@ namespace infini_train::nn { Module::Module() : Module(kUndefinedType) {} -Module::Module(const std::string &type) : type_(type), device_(DeviceManager::Instance()->GetDefaultDevice()) {} +Module::Module(const std::string &type) : type_(type), device_(Device()) {} const std::string &Module::type() const { return type_; } @@ -223,8 +223,7 @@ std::vector> Module::operator()(const std::vector &normalized_shape, float eps, const Device *device) : CloneableModule(kType), eps_(eps) { - device_ = device ? device : DeviceManager::Instance()->GetDefaultDevice(); + device_ = device ? device : Device(); parameters_[kParamWeightName] = std::make_shared(normalized_shape, DataType::kFLOAT32, device_)->RequiresGrad(); diff --git a/infini_train/src/nn/modules/sparse.cc b/infini_train/src/nn/modules/sparse.cc index ab845697..2fdeafb8 100644 --- a/infini_train/src/nn/modules/sparse.cc +++ b/infini_train/src/nn/modules/sparse.cc @@ -11,7 +11,7 @@ namespace infini_train::nn { Embedding::Embedding(int num_embeddings, int embedding_dim, const Device *device) : CloneableModule(kType) { - device_ = device ? device : DeviceManager::Instance()->GetDefaultDevice(); + device_ = device ? device : Device(); parameters_[kParamWeightName] = std::make_shared(std::vector{num_embeddings, embedding_dim}, DataType::kFLOAT32, device_) diff --git a/infini_train/src/nn/parallel/data_parallel.cc b/infini_train/src/nn/parallel/data_parallel.cc index 1a64ab8a..b68836a4 100644 --- a/infini_train/src/nn/parallel/data_parallel.cc +++ b/infini_train/src/nn/parallel/data_parallel.cc @@ -10,6 +10,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/tensor.h" @@ -57,8 +58,10 @@ ParallelApply(const std::vector> &modules, } } // namespace -DataParallel::DataParallel(const std::shared_ptr &module, int dim) - : dim_(dim), devices_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA)) { +DataParallel::DataParallel(const std::shared_ptr &module, int dim, Device::DeviceType device_type) : dim_(dim) { + devices_.reserve(global::GetNthreadPerProc()); + for (int index = 0; index < global::GetNthreadPerProc(); ++index) { devices_.emplace_back(device_type, index); } + CHECK_GT(devices_.size(), 0) << "No available devices found"; output_device_ = devices_.at(0); src_device_ = devices_.at(0); diff --git a/infini_train/src/nn/parallel/ddp/reducer.cc b/infini_train/src/nn/parallel/ddp/reducer.cc index d32ca668..092fa74b 100644 --- a/infini_train/src/nn/parallel/ddp/reducer.cc +++ b/infini_train/src/nn/parallel/ddp/reducer.cc @@ -27,12 +27,12 @@ void CopyGradToBucket(const std::shared_ptr &grad, const std::shared_ptr const void *src = grad->DataPtr(); const auto dev_type = grad->GetDevice()->Type(); - if (dev_type == DeviceType::kCPU) { + if (dev_type == Device::DeviceType::kCPU) { std::memcpy(dst, src, bytes); return; } #ifdef USE_CUDA - if (dev_type == DeviceType::kCUDA) { + if (dev_type == Device::DeviceType::kCUDA) { auto *cuda_dev = dynamic_cast(flat->GetDevice()); CHECK(cuda_dev); 
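
The hunks around here all follow the same pattern: a default-constructed Device now stands in for the old DeviceManager default (CPU:0), and call sites branch on the nested Device::DeviceType. A small sketch of that value-type behaviour, using only the constructor, the type()/index() accessors, and the operator==/!= added in device.cc earlier in this patch:

#include "infini_train/include/device.h"

using namespace infini_train;

void DeviceValueSemantics() {
    Device cpu;                                // default-constructed: {kCPU, 0}
    Device gpu1(Device::DeviceType::kCUDA, 1); // explicit type + index

    // Devices are now cheap values: copy and compare instead of pointer identity.
    Device copy = gpu1;
    bool same = (copy == gpu1);      // true: identical (type, index) pair
    bool different = (cpu != gpu1);  // true

    // Call sites switch on the nested enum instead of the old top-level DeviceType.
    switch (gpu1.type()) {
    case Device::DeviceType::kCPU:
        /* host path */
        break;
    case Device::DeviceType::kCUDA:
        /* device path; index() selects the GPU */
        break;
    default:
        break;
    }
    (void)same;
    (void)different;
}
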
cuda_dev->SetDevice(); @@ -53,12 +53,12 @@ void CopyBucketToGrad(const std::shared_ptr &flat, const std::shared_ptr void *dst = grad->DataPtr(); const auto dev_type = grad->GetDevice()->Type(); - if (dev_type == DeviceType::kCPU) { + if (dev_type == Device::DeviceType::kCPU) { std::memcpy(dst, src, bytes); return; } #ifdef USE_CUDA - if (dev_type == DeviceType::kCUDA) { + if (dev_type == Device::DeviceType::kCUDA) { auto *cuda_dev = dynamic_cast(flat->GetDevice()); CHECK(cuda_dev); cuda_dev->SetDevice(); diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 1d235901..7496017b 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -248,8 +248,7 @@ float PipelineSchedule::StepMicroBatches(const std::vector(target_on_device)})[0]; loss = loss / n; } - total_loss - += static_cast(loss->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + total_loss += static_cast(loss->To(Device()).DataPtr())[0]; loss->Backward(); } else { diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index 1c13a001..6f02662c 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -11,11 +11,10 @@ namespace infini_train::nn::parallel { PipelineStage::PipelineStage(int stage_index /* pp_rank */, int num_stages /* pp_size */, - const std::vector> &recv_shape, int device_id, + const std::vector> &recv_shape, Device device, std::vector> &&chunks) : stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1), - next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recv_shape), - device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)), + next_rank_(stage_index < num_stages - 1 ? 
stage_index + 1 : -1), recv_shape_(recv_shape), device_(device), chunks_(std::move(chunks)) {} std::vector> PipelineStage::ForwardOneChunk(const std::vector> &inputs, diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 50a75d48..45a3eac7 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -128,7 +128,7 @@ void ProcessGroupNCCL::InitSingleProcess(const std::vector &ranks) { NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data())); for (int i = 0; i < ranks.size(); ++i) { - auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ranks[i]); + auto device = Device(Device::DeviceType::kCUDA, ranks[i]); devices_.push_back(device); device_comm_map_[device] = comms_[i]; global_group_rank_map_[device->rank().GlobalRank()] = i; @@ -165,7 +165,7 @@ void ProcessGroupNCCL::InitMultiProcess(const std::vector &ranks) { NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank)); comms_.push_back(comm); - auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i); + auto device = Device(Device::DeviceType::kCUDA, i); global_group_rank_map_[device->rank().GlobalRank()] = group_rank; devices_.push_back(device); device_comm_map_[device] = comm; diff --git a/infini_train/src/profiler.cc b/infini_train/src/profiler.cc index f2be2f4b..6464c24c 100644 --- a/infini_train/src/profiler.cc +++ b/infini_train/src/profiler.cc @@ -38,8 +38,8 @@ Profiler &Profiler::Instance() { return profiler; } -int GetRank(DeviceType device) { - if (device == DeviceType::kCPU) { +int GetRank(Device::DeviceType device) { + if (device == Device::DeviceType::kCPU) { return 0; } @@ -53,25 +53,24 @@ int GetRank(DeviceType device) { #ifdef USE_CUDA cudaStream_t GetCudaStream() { - int device_id = GetRank(DeviceType::kCUDA); + int device_id = GetRank(Device::DeviceType::kCUDA); // TODO(zbl): support multi-stream on single device - return dynamic_cast( - DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, static_cast(device_id))) + return dynamic_cast(Device(Device::DeviceType::kCUDA, static_cast(device_id))) ->Stream(); } #endif -void Profiler::StartRecord(const std::string &name, DeviceType device) { +void Profiler::StartRecord(const std::string &name, Device::DeviceType device) { if (g_profiling_depth++ > 0) { return; } cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); switch (device) { - case DeviceType::kCPU: + case Device::DeviceType::kCPU: break; #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { auto it = cuda_timing_map_.find(name); if (it != cuda_timing_map_.end()) { // Make sure there are no conflicts @@ -100,7 +99,7 @@ void Profiler::StartRecord(const std::string &name, DeviceType device) { } } -void Profiler::EndRecord(const std::string &name, DeviceType device) { +void Profiler::EndRecord(const std::string &name, Device::DeviceType device) { if (--g_profiling_depth > 0) { return; } @@ -110,10 +109,10 @@ void Profiler::EndRecord(const std::string &name, DeviceType device) { int rank = GetRank(device); switch (device) { - case DeviceType::kCPU: + case Device::DeviceType::kCPU: break; #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { auto it = cuda_timing_map_.find(name); if (it != cuda_timing_map_.end()) { auto event_pair = it->second; diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 05257953..cb351d03 100644 --- a/infini_train/src/tensor.cc +++ 
b/infini_train/src/tensor.cc @@ -31,14 +31,14 @@ #include "infini_train/include/nn/init.h" namespace infini_train { -TensorBuffer::TensorBuffer(const Device *device, size_t size) : device_(device), size_(size) { +TensorBuffer::TensorBuffer(Device device, size_t size) : device_(device), size_(size) { CHECK_NOTNULL(device); - switch (device_->Type()) { - case DeviceType::kCPU: + switch (device.type()) { + case Device::DeviceType::kCPU: data_ = malloc(size); break; #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { int current_device = -1; CUDA_CHECK(cudaGetDevice(¤t_device)); // TODO(dcj): Maybe pin memory later. @@ -50,23 +50,23 @@ TensorBuffer::TensorBuffer(const Device *device, size_t size) : device_(device), } #endif default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device_->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(device_.type()); break; } } TensorBuffer::~TensorBuffer() { - switch (device_->Type()) { - case DeviceType::kCPU: + switch (device_.type()) { + case Device::DeviceType::kCPU: free(data_); break; #ifdef USE_CUDA - case DeviceType::kCUDA: + case Device::DeviceType::kCUDA: CUDA_CHECK(cudaFreeAsync(data_, dynamic_cast(device_)->Stream())); break; #endif default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device_->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(device_.type()); break; } } @@ -75,12 +75,12 @@ void *TensorBuffer::DataPtr() { return data_; } const void *TensorBuffer::DataPtr() const { return data_; } -const Device *TensorBuffer::GetDevice() const { return device_; } +Device TensorBuffer::GetDevice() const { return device_; } size_t TensorBuffer::Size() const { return size_; } // Tensor implementation -Tensor::Tensor(const std::vector &dims, DataType dtype, const Device *device) : dims_(dims), dtype_(dtype) { +Tensor::Tensor(const std::vector &dims, DataType dtype, Device device) : dims_(dims), dtype_(dtype) { num_elements_ = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); buffer_ = std::make_shared(device, kDataTypeToSize.at(dtype) * num_elements_); } @@ -91,25 +91,25 @@ Tensor::Tensor(const Tensor &tensor, size_t offset, const std::vector & CHECK_LE(offset_ + kDataTypeToSize.at(dtype_) * num_elements_, buffer_->Size()); } -Tensor::Tensor(const float *data, const std::vector &dims, DataType dtype, const Device *device) +Tensor::Tensor(const float *data, const std::vector &dims, DataType dtype, Device device) : dims_(dims), dtype_(dtype), num_elements_(std::accumulate(dims.begin(), dims.end(), 1, std::multiplies())) { // TODO(dcj): support more datatype CHECK(dtype == DataType::kFLOAT32); buffer_ = std::make_shared(device, kDataTypeToSize.at(dtype) * num_elements_); - switch (device->Type()) { - case DeviceType::kCPU: + switch (device.type()) { + case Device::DeviceType::kCPU: memcpy(buffer_->DataPtr(), data, buffer_->Size()); break; #ifdef USE_CUDA - case DeviceType::kCUDA: + case Device::DeviceType::kCUDA: CUDA_CHECK(cudaMemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream())); break; #endif default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); } } @@ -129,7 +129,7 @@ void Tensor::SetData(const Tensor &tensor, size_t offset, bool preserve_data) { offset_ = tensor.offset_ + offset; } -const Device *Tensor::GetDevice() const { return buffer_->GetDevice(); } +Device Tensor::GetDevice() 
const { return buffer_->GetDevice(); } void *Tensor::DataPtr() { return reinterpret_cast(buffer_->DataPtr()) + offset_; } @@ -156,7 +156,7 @@ template void Tensor::Fill(T value) { std::memcpy((void *)(&storage), &casted_value, sizeof(TargetT)); }); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "Fill"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "Fill"}); kernel.Call(shared_from_this(), static_cast(&storage)); } @@ -187,7 +187,7 @@ Eigen::Map> Tensor::Eig dims_[0]); } -Tensor Tensor::To(const Device *device) { +Tensor Tensor::To(Device device) { if (device == buffer_->GetDevice()) { auto new_tensor = Tensor(*this, offset_, dims_); if (grad_) { @@ -197,29 +197,29 @@ Tensor Tensor::To(const Device *device) { } Tensor new_tensor; - switch (device->Type()) { + switch (device.type()) { #ifdef USE_CUDA - case DeviceType::kCPU: { + case Device::DeviceType::kCPU: { // CUDA -> CPU GetDevice()->SetDevice(); - new_tensor = Tensor(dims_, dtype_, DeviceManager::Instance()->GetDefaultDevice()); + new_tensor = Tensor(dims_, dtype_, Device()); CUDA_CHECK(cudaMemcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), cudaMemcpyDeviceToHost)); break; } - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { int current_device = -1; CUDA_CHECK(cudaGetDevice(¤t_device)); new_tensor = Tensor(dims_, dtype_, device); - if (GetDevice()->Type() == DeviceType::kCPU) { + if (GetDevice().type() == Device::DeviceType::kCPU) { device->SetDevice(); // CPU -> CUDA CUDA_CHECK(cudaMemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream())); } else { - // CUDA -> CUDA + // p2p // 1. CUDA -> CPU // 2. CPU -> CUDA - Tensor cpu_tensor = To(DeviceManager::Instance()->GetDefaultDevice()); + Tensor cpu_tensor = To(Device()); device->SetDevice(); CUDA_CHECK(cudaMemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), cudaMemcpyHostToDevice, dynamic_cast(device)->Stream())); @@ -229,7 +229,7 @@ Tensor Tensor::To(const Device *device) { } #endif default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); } if (grad_) { @@ -253,7 +253,7 @@ Tensor Tensor::To(DataType dtype) { auto device = GetDevice(); device->SetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "Cast"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "Cast"}); auto new_tensor = *kernel.Call>(shared_from_this(), dtype); if (grad_) { @@ -272,18 +272,18 @@ void Tensor::CopyFrom(const Tensor &src) { CHECK(Dims() == src.Dims()) << "Tensor::CopyFrom shape mismatch"; const size_t nbytes = SizeInBytes(); - const Device *dst_dev = GetDevice(); - const Device *src_dev = src.GetDevice(); + const Device dst_dev = GetDevice(); + const Device src_dev = src.GetDevice(); - switch (dst_dev->Type()) { - case DeviceType::kCPU: { - switch (src_dev->Type()) { - case DeviceType::kCPU: { + switch (dst_dev.type()) { + case Device::DeviceType::kCPU: { + switch (src_dev.type()) { + case Device::DeviceType::kCPU: { std::memcpy(DataPtr(), src.DataPtr(), nbytes); break; } #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { // CUDA -> CPU CUDA_CHECK(cudaMemcpy(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToHost)); break; @@ -296,38 +296,38 @@ void Tensor::CopyFrom(const Tensor &src) { } #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { int current_device = -1; 
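
The tensor.cc changes above make the same switch at the Tensor level: constructors, GetDevice(), and To() now traffic in Device values, with Device() meaning host memory. A hedged usage sketch, limited to the constructor and the To(Device) overload visible in this hunk (the dims and dtype are illustrative):

#include <cstdint>
#include <vector>

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"
#include "infini_train/include/tensor.h"

using namespace infini_train;

void RoundTripThroughCuda() {
    Device gpu0(Device::DeviceType::kCUDA, 0);

    // Allocate a 2x3 float tensor directly on cuda:0; Device is passed by value now.
    Tensor on_gpu(std::vector<int64_t>{2, 3}, DataType::kFLOAT32, gpu0);

    // Device() is the CPU default, so To(Device()) copies back to host, exactly as the
    // cross_entropy.cu and pipeline_schedule.cc hunks in this patch do.
    Tensor on_cpu = on_gpu.To(Device());

    // GetDevice() returns a value; comparisons use the new operator==.
    bool back_on_host = (on_cpu.GetDevice() == Device());
    (void)back_on_host;
}
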
CUDA_CHECK(cudaGetDevice(¤t_device)); dst_dev->SetDevice(); const auto *dst_cuda = dynamic_cast(dst_dev); - switch (src_dev->Type()) { - case DeviceType::kCPU: { + switch (src_dev.type()) { + case Device::DeviceType::kCPU: { // CPU -> CUDA CUDA_CHECK(cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyHostToDevice, dst_cuda->Stream())); break; } - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { const auto *src_cuda = dynamic_cast(src_dev); - if (src_cuda->Index() == dst_cuda->Index()) { + if (src_cuda.index() == dst_cuda.index()) { CUDA_CHECK( cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToDevice, dst_cuda->Stream())); } else { int canAccessPeer = 0; - CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, dst_cuda->Index(), src_cuda->Index())); + CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, dst_cuda.index(), src_cuda.index())); if (canAccessPeer) { - CUDA_CHECK(cudaMemcpyPeerAsync(DataPtr(), dst_cuda->Index(), src.DataPtr(), src_cuda->Index(), - nbytes, dst_cuda->Stream())); + CUDA_CHECK(cudaMemcpyPeerAsync(DataPtr(), dst_cuda.index(), src.DataPtr(), src_cuda.index(), nbytes, + dst_cuda->Stream())); } else { - LOG(FATAL) << "Check accessibility between Device " << src_cuda->Index() << " and Device " - << dst_cuda->Index(); + LOG(FATAL) << "Check accessibility between Device " << src_cuda.index() << " and Device " + << dst_cuda.index(); } } break; } default: - LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev->Type()); + LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev.type()); } CUDA_CHECK(cudaSetDevice(current_device)); @@ -336,7 +336,7 @@ void Tensor::CopyFrom(const Tensor &src) { #endif default: - LOG(FATAL) << "Unsupported dst device type: " << static_cast(dst_dev->Type()); + LOG(FATAL) << "Unsupported dst device type: " << static_cast(dst_dev.type()); } } @@ -392,7 +392,7 @@ std::shared_ptr Tensor::Or(const std::shared_ptr &other) { } std::shared_ptr Tensor::Add(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } @@ -401,12 +401,12 @@ std::shared_ptr Tensor::Add(float scalar) { } std::shared_ptr Tensor::Sub(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } std::shared_ptr Tensor::Mul(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } @@ -415,7 +415,7 @@ std::shared_ptr Tensor::Mul(float scalar) { } std::shared_ptr Tensor::Div(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } @@ -633,7 +633,7 @@ void Tensor::Backward(std::shared_ptr gradient, bool retain_graph, bool gradient = std::make_shared(std::vector{}, dtype_, GetDevice()); gradient->Fill(1.0f); } else { - CHECK_EQ(static_cast(GetDevice()->Type()), 
static_cast(gradient->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(gradient->GetDevice().type())); CHECK_EQ(static_cast(dtype_), static_cast(gradient->Dtype())); CHECK_EQ(dims_.size(), gradient->Dims().size()); for (int idx = 0; idx < dims_.size(); ++idx) { CHECK_EQ(dims_[idx], gradient->Dims()[idx]); } @@ -773,12 +773,12 @@ void Tensor::SaveAsNpy(const std::string &path) const { // Prepare host buffer std::vector host_buffer(num_elements); - if (GetDevice()->Type() == DeviceType::kCPU) { + if (GetDevice().type() == Device::DeviceType::kCPU) { // If on CPU, direct copy std::memcpy(host_buffer.data(), DataPtr(), num_bytes); } #ifdef USE_CUDA - else if (GetDevice()->Type() == DeviceType::kCUDA) { + else if (GetDevice().type() == Device::DeviceType::kCUDA) { // If on CUDA, copy back to host cudaDeviceSynchronize(); cudaError_t err = cudaMemcpy(host_buffer.data(), DataPtr(), num_bytes, cudaMemcpyDeviceToHost); @@ -894,11 +894,11 @@ void Tensor::Print(std::ostream &os) const { std::vector host_buffer(num_elements); - if (GetDevice()->Type() == DeviceType::kCPU) { + if (GetDevice().type() == Device::DeviceType::kCPU) { std::memcpy(host_buffer.data(), DataPtr(), num_bytes); } #ifdef USE_CUDA - else if (GetDevice()->Type() == DeviceType::kCUDA) { + else if (GetDevice().type() == Device::DeviceType::kCUDA) { cudaDeviceSynchronize(); cudaError_t err = cudaMemcpy(host_buffer.data(), DataPtr(), num_bytes, cudaMemcpyDeviceToHost); CHECK_EQ(err, cudaSuccess) << "cudaMemcpy failed: " << cudaGetErrorString(err); From abe7b96835378a170b355d2ada0df4c47e579998 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 16 Jan 2026 09:26:25 +0000 Subject: [PATCH 3/7] refactor: replace all Device* usages with Device value type --- example/common/tokenizer.cc | 7 +++-- example/common/tokenizer.h | 10 +++---- example/llama3/net.cc | 4 +-- example/llama3/net.h | 2 +- infini_train/include/autograd/comm.h | 26 +++++++++---------- infini_train/include/nn/init.h | 3 ++- infini_train/include/nn/modules/linear.h | 3 ++- infini_train/include/nn/modules/module.h | 6 ++--- .../include/nn/modules/normalization.h | 4 +-- infini_train/include/nn/modules/sparse.h | 3 ++- .../include/nn/parallel/data_parallel.h | 8 +++--- .../include/nn/parallel/parallel_functional.h | 11 ++++---- .../include/nn/parallel/pp/pipeline_stage.h | 6 +++-- .../include/nn/parallel/pp/send_recv.h | 9 ++++--- .../include/nn/parallel/process_group.h | 25 +++++++++--------- infini_train/include/nn/parallel/work.h | 8 +++--- infini_train/src/autograd/comm.cc | 10 +++---- infini_train/src/kernels/cuda/comm.cu | 9 +++---- infini_train/src/nn/init.cc | 2 +- infini_train/src/nn/modules/linear.cc | 2 +- infini_train/src/nn/modules/normalization.cc | 2 +- infini_train/src/nn/modules/sparse.cc | 5 ++-- infini_train/src/nn/parallel/data_parallel.cc | 5 ++-- .../src/nn/parallel/parallel_functional.cc | 9 +++---- .../src/nn/parallel/pp/pipeline_stage.cc | 2 +- infini_train/src/nn/parallel/pp/send_recv.cc | 19 +++++++------- infini_train/src/nn/parallel/process_group.cc | 12 ++++----- infini_train/src/nn/parallel/work.cc | 2 +- 28 files changed, 106 insertions(+), 108 deletions(-) diff --git a/example/common/tokenizer.cc b/example/common/tokenizer.cc index cfbe4df2..9541454a 100644 --- a/example/common/tokenizer.cc +++ b/example/common/tokenizer.cc @@ -10,6 +10,9 @@ #include "glog/logging.h" #include "example/common/utils.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/modules/module.h" +#include 
"infini_train/include/tensor.h" namespace infini_train { @@ -103,7 +106,7 @@ std::string Tokenizer::Decode(uint32_t token_id) const { } void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length, - uint32_t text_length, const Device *device) const { + uint32_t text_length, Device device) const { std::vector dims; dims.assign({batch_size, sequence_length}); // x_tensor (FLAGS_batch_size, FLAGS_sequence_length) eq:(4, 64) @@ -121,7 +124,7 @@ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_siz uint64_t kRngState = kRngState; LOG(INFO) << "start generate text:"; - const auto *cpu_device = Device(); + auto cpu_device = Device(); for (int t = prompt_len; t < text_length; ++t) { x = std::make_shared(x->To(device)); // CPU->calc device // TODO(jym): use no_grad forward later diff --git a/example/common/tokenizer.h b/example/common/tokenizer.h index af42dd24..c9d0b76c 100644 --- a/example/common/tokenizer.h +++ b/example/common/tokenizer.h @@ -1,15 +1,13 @@ #include #include -#include #include #include "infini_train/include/device.h" -#include "infini_train/include/nn/functional.h" -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/tensor.h" namespace infini_train { - +namespace nn { +class Module; +} class Tokenizer { public: enum class Version : uint32_t { @@ -22,7 +20,7 @@ class Tokenizer { std::string Decode(uint32_t token_id) const; void GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length, - uint32_t text_length, const Device *device) const; + uint32_t text_length, Device device) const; uint32_t GetEndToken() const { return eot_token_; }; diff --git a/example/llama3/net.cc b/example/llama3/net.cc index b92e35a8..a50fb831 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -100,7 +100,7 @@ ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, - const infini_train::Device *device = Device()) { + infini_train::Device device = Device()) { DataType dtype = DataType::kFLOAT32; CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); @@ -126,7 +126,7 @@ std::vector> SwiGLU::Forward(const std::vector(std::vector{dim}, DataType::kFLOAT32, device)->RequiresGrad(); nn::init::Ones(parameters_[kParamWeightName]); diff --git a/example/llama3/net.h b/example/llama3/net.h index 845d56e6..4496a68d 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -55,7 +55,7 @@ class RMSNorm : public infini_train::nn::CloneableModule { static constexpr char kType[] = "RMSNorm"; static constexpr char kParamWeightName[] = "weight"; - explicit RMSNorm(int64_t dim, float eps = 1e-6f, const infini_train::Device *device = infini_train::Device()); + explicit RMSNorm(int64_t dim, float eps = 1e-6f, infini_train::Device device = infini_train::Device()); std::vector> Forward(const std::vector> &x) override; diff --git a/infini_train/include/autograd/comm.h b/infini_train/include/autograd/comm.h index b54c814f..e74c821d 100644 --- a/infini_train/include/autograd/comm.h +++ b/infini_train/include/autograd/comm.h @@ -4,6 +4,7 @@ #include #include "infini_train/include/autograd/function.h" +#include "infini_train/include/device.h" namespace infini_train { class Tensor; @@ -19,7 +20,7 @@ class Scatter : public autograd::Function { public: static constexpr char kType[] = 
"ScatterFunction"; - explicit Scatter(const std::vector &target_gpus, int64_t dim, + explicit Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -31,8 +32,8 @@ class Scatter : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - std::vector target_gpus_; - const Device *input_device_ = nullptr; + std::vector target_gpus_; + Device input_device_ = Device(); int64_t dim_ = 0; }; @@ -40,8 +41,7 @@ class Gather : public autograd::Function { public: static constexpr char kType[] = "GatherFunction"; - explicit Gather(const Device *target_device, int64_t dim, - const infini_train::nn::parallel::ProcessGroup *pg = nullptr); + explicit Gather(Device target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -52,8 +52,8 @@ class Gather : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - const Device *target_device_ = nullptr; - std::vector input_gpus_; + Device target_device_ = Device(); + std::vector input_gpus_; int64_t dim_ = 0; bool unsqueezed_scalar_ = false; }; @@ -62,7 +62,7 @@ class Broadcast : public autograd::Function { public: static constexpr char kType[] = "BroadcastFunction"; - explicit Broadcast(const std::vector &target_gpus, + explicit Broadcast(const std::vector &target_gpus, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -74,16 +74,16 @@ class Broadcast : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - std::vector target_gpus_; + std::vector target_gpus_; int64_t num_inputs_ = 0; - const Device *input_device_ = nullptr; + Device input_device_ = Device(); }; class ReduceAddCoalesced : public autograd::Function { public: static constexpr char kType[] = "ReduceAddCoalescedFunction"; - explicit ReduceAddCoalesced(const Device *destination, int64_t num_inputs, + explicit ReduceAddCoalesced(Device destination, int64_t num_inputs, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -95,8 +95,8 @@ class ReduceAddCoalesced : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - const Device *destination_ = nullptr; - std::vector target_gpus_; + Device destination_ = Device(); + std::vector target_gpus_; int64_t num_inputs_ = 0; }; } // namespace infini_train::autograd diff --git a/infini_train/include/nn/init.h b/infini_train/include/nn/init.h index 644df590..fc6effec 100644 --- a/infini_train/include/nn/init.h +++ b/infini_train/include/nn/init.h @@ -6,6 +6,7 @@ #include #include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" namespace infini_train { class Tensor; @@ -50,5 +51,5 @@ std::shared_ptr Ones(const std::shared_ptr &tensor); std::shared_ptr Zeros(const std::shared_ptr &tensor); -std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, const Device *device = nullptr); +std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Device device = Device()); } // namespace infini_train::nn::init diff --git a/infini_train/include/nn/modules/linear.h b/infini_train/include/nn/modules/linear.h index e02b91a6..c4103df6 100644 --- 
a/infini_train/include/nn/modules/linear.h +++ b/infini_train/include/nn/modules/linear.h @@ -3,6 +3,7 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { @@ -18,7 +19,7 @@ class Linear : public CloneableModule { static constexpr char kParamWeightName[] = "weight"; static constexpr char kParamBiasName[] = "bias"; - Linear(int64_t in_features, int64_t out_features, bool bias = true, const Device *device = nullptr); + Linear(int64_t in_features, int64_t out_features, bool bias = true, Device device = Device()); std::vector> Forward(const std::vector> &input_tensors) override; private: diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 98dc7e5b..57e750ae 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -22,7 +22,7 @@ class Module; namespace parallel::function { std::vector> Replicate(const std::shared_ptr &network, - const std::vector &devices); + const std::vector &devices); } // namespace parallel::function class Module : public std::enable_shared_from_this { @@ -104,8 +104,8 @@ class Module : public std::enable_shared_from_this { std::vector backward_post_hooks_; private: - friend std::vector> - parallel::function::Replicate(const std::shared_ptr &network, const std::vector &devices); + friend std::vector> parallel::function::Replicate(const std::shared_ptr &network, + const std::vector &devices); }; template class CloneableModule : public Module { diff --git a/infini_train/include/nn/modules/normalization.h b/infini_train/include/nn/modules/normalization.h index 111e96b7..6119d584 100644 --- a/infini_train/include/nn/modules/normalization.h +++ b/infini_train/include/nn/modules/normalization.h @@ -3,11 +3,11 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn { @@ -17,7 +17,7 @@ class LayerNorm : public CloneableModule { static constexpr char kParamWeightName[] = "weight"; static constexpr char kParamBiasName[] = "bias"; - LayerNorm(const std::vector &normalized_shape, float eps = 1e-5f, const Device *device = nullptr); + LayerNorm(const std::vector &normalized_shape, float eps = 1e-5f, Device device = Device()); std::vector> Forward(const std::vector> &input_tensors) override; private: diff --git a/infini_train/include/nn/modules/sparse.h b/infini_train/include/nn/modules/sparse.h index e0605a6e..51975160 100644 --- a/infini_train/include/nn/modules/sparse.h +++ b/infini_train/include/nn/modules/sparse.h @@ -3,6 +3,7 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { @@ -17,7 +18,7 @@ class Embedding : public CloneableModule { static constexpr char kParamWeightName[] = "weight"; - Embedding(int num_embeddings, int embedding_dim, const Device *device = nullptr); + Embedding(int num_embeddings, int embedding_dim, Device device = Device()); std::vector> Forward(const std::vector> &input_tensors) override; private: diff --git a/infini_train/include/nn/parallel/data_parallel.h b/infini_train/include/nn/parallel/data_parallel.h index 581d6c3b..7d97f282 100644 --- a/infini_train/include/nn/parallel/data_parallel.h +++ b/infini_train/include/nn/parallel/data_parallel.h @@ -3,11 +3,11 @@ #include #include +#include "infini_train/include/device.h" 
#include "infini_train/include/nn/modules/module.h" namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn::parallel { @@ -19,8 +19,8 @@ class DataParallel : public Module { private: int dim_ = 0; - std::vector devices_; - const Device *output_device_ = nullptr; - const Device *src_device_ = nullptr; + std::vector devices_; + Device output_device_ = Device(); + Device src_device_ = Device(); }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h index f2559e2d..2eed56f4 100644 --- a/infini_train/include/nn/parallel/parallel_functional.h +++ b/infini_train/include/nn/parallel/parallel_functional.h @@ -3,12 +3,12 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" namespace infini_train { class Tensor; -class Device; namespace nn { class Module; } @@ -26,16 +26,15 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false); std::vector>> Scatter(const std::vector> &input_tensors, - const std::vector &device_ids, int dim); + const std::vector &device_ids, int dim); std::vector> Gather(const std::vector>> &outputs, - const Device *target_device, int dim); + Device target_device, int dim); std::vector>> -BroadcastCoalescedReshape(const std::vector> &tensors, - const std::vector &devices); +BroadcastCoalescedReshape(const std::vector> &tensors, const std::vector &devices); std::vector> Replicate(const std::shared_ptr &network, - const std::vector &devices); + const std::vector &devices); } // namespace infini_train::nn::parallel::function diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index b59369dc..d1a21605 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -3,6 +3,8 @@ #include #include +#include "infini_train/include/device.h" + namespace infini_train { class Tensor; class Device; @@ -30,7 +32,7 @@ class PipelineStage { int next_rank() const; int num_stages() const; - const Device *device() const; + Device device() const; const std::vector> &recv_shape() const; const std::vector> &chunks(); std::vector> *mutable_chunks(); @@ -40,7 +42,7 @@ class PipelineStage { int num_stages_ = -1; int prev_rank_ = -1; int next_rank_ = -1; - const Device *device_ = nullptr; + Device device_ = Device(); std::vector> chunks_; std::vector> recv_shape_; }; diff --git a/infini_train/include/nn/parallel/pp/send_recv.h b/infini_train/include/nn/parallel/pp/send_recv.h index f76f4c72..4f8687ab 100644 --- a/infini_train/include/nn/parallel/pp/send_recv.h +++ b/infini_train/include/nn/parallel/pp/send_recv.h @@ -3,17 +3,18 @@ #include #include +#include "infini_train/include/device.h" + namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn::parallel { std::vector> ISend(const std::vector> &input_tensors, - const Device *target_device, int cur_rank, int peer_rank, + Device target_device, int cur_rank, int peer_rank, const std::vector> &shape); -std::vector> IRecv(const std::vector> &outputs, - const Device *src_device, int cur_rank, int peer_rank); +std::vector> IRecv(const std::vector> &outputs, Device src_device, + int cur_rank, 
int peer_rank); } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index d739f67d..79d84478 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -12,11 +12,11 @@ #include #endif +#include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" namespace infini_train { class Tensor; -class Device; namespace nn { class Module; namespace parallel { @@ -62,21 +62,20 @@ class ProcessGroup { BroadCast(const std::vector> &input_tensors) const = 0; virtual std::vector> - ReduceAddCoalesced(const std::vector>> &grads, const Device *destination) const - = 0; + ReduceAddCoalesced(const std::vector>> &grads, Device destination) const = 0; virtual std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const + std::vector devices, int64_t dim) const = 0; - virtual std::shared_ptr Gather(const std::vector> &tensors, - const Device *destination, int64_t dim) const + virtual std::shared_ptr Gather(const std::vector> &tensors, Device destination, + int64_t dim) const = 0; protected: ProcessGroup(int world_size, const std::string &name); - std::vector devices_; + std::vector devices_; std::unordered_map global_group_rank_map_; // global_rank : group_rank @@ -116,12 +115,12 @@ class ProcessGroupNCCL final : public ProcessGroup { std::vector> ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) const override; + Device destination) const override; - std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const override; + std::vector> Scatter(const std::shared_ptr &tensor, std::vector devices, + int64_t dim) const override; - std::shared_ptr Gather(const std::vector> &tensors, const Device *destination, + std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) const override; private: @@ -135,8 +134,8 @@ class ProcessGroupNCCL final : public ProcessGroup { std::vector comms_; std::vector comm_streams_; - std::unordered_map device_comm_map_; - std::unordered_map device_stream_map_; + std::unordered_map device_comm_map_; + std::unordered_map device_stream_map_; }; #endif diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h index 1e11cc02..8cc60f78 100644 --- a/infini_train/include/nn/parallel/work.h +++ b/infini_train/include/nn/parallel/work.h @@ -12,9 +12,7 @@ #include #endif -namespace infini_train { -class Device; -} // namespace infini_train +#include "infini_train/include/device.h" namespace infini_train::nn::parallel { @@ -39,7 +37,7 @@ class Work { #ifdef USE_NCCL class WorkNccl final : public Work { public: - WorkNccl(const Device *device, ncclComm_t comm); + WorkNccl(Device device, ncclComm_t comm); ~WorkNccl() override; bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override; @@ -60,7 +58,7 @@ class WorkNccl final : public Work { void SetException(std::exception_ptr e); private: - const Device *device_ = nullptr; + Device device_ = Device(); cudaEvent_t ready_event_; cudaEvent_t done_event_; ncclComm_t comm_; diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index 48e70f09..712d8ea7 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -10,7 +10,7 @@ namespace infini_train::autograd { -Scatter::Scatter(const std::vector 
&target_gpus, int64_t dim, +Scatter::Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), target_gpus_(target_gpus), dim_(dim), pg_(pg ? pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} @@ -32,7 +32,7 @@ std::vector> Scatter::Backward(const std::vector(input_device_, dim_)->Apply(grad_outputs); } -Gather::Gather(const Device *target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) +Gather::Gather(Device target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), target_device_(target_device), dim_(dim), pg_(pg ? pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} @@ -62,10 +62,10 @@ void Gather::SetupContext(const std::vector> &input_tens std::vector> Gather::Backward(const std::vector> &grad_outputs) { // TODO(dcj): do squeeze here if unsqueezed_scalar_ is true - return std::make_shared(std::vector{input_gpus_}, dim_)->Apply(grad_outputs); + return std::make_shared(std::vector{input_gpus_}, dim_)->Apply(grad_outputs); } -Broadcast::Broadcast(const std::vector &target_gpus, const infini_train::nn::parallel::ProcessGroup *pg) +Broadcast::Broadcast(const std::vector &target_gpus, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), target_gpus_(target_gpus), pg_(pg ? pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} @@ -95,7 +95,7 @@ std::vector> Broadcast::Backward(const std::vector(input_device_, num_inputs_)->Apply(grad_outputs); } -ReduceAddCoalesced::ReduceAddCoalesced(const Device *destination, int64_t num_inputs, +ReduceAddCoalesced::ReduceAddCoalesced(Device destination, int64_t num_inputs, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), destination_(destination), num_inputs_(num_inputs), pg_(pg ? 
pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index 6fc7adeb..9c22d9d9 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -12,7 +12,7 @@ namespace infini_train::kernels::cuda { std::vector> Broadcast(const std::vector> &input_tensors, - const std::vector &devices) { + const std::vector &devices) { std::vector> outputs; for (int i = 0; i < devices.size(); ++i) { for (const auto &tensor : input_tensors) { @@ -23,7 +23,7 @@ std::vector> Broadcast(const std::vector> ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) { + Device destination) { std::vector> outputs; auto kernel = Dispatcher::Instance().GetKernel({destination->Type(), "AccumulateGrad"}); std::vector>> to_destination_grads; @@ -45,7 +45,7 @@ std::vector> ReduceAddCoalesced(const std::vector> Scatter(const std::shared_ptr &tensor, std::vector devices, +std::vector> Scatter(const std::shared_ptr &tensor, std::vector devices, int64_t dim) { std::vector> outputs; // FIXME(dcj): do split without autograd @@ -56,8 +56,7 @@ std::vector> Scatter(const std::shared_ptr &tens return outputs; } -std::shared_ptr Gather(const std::vector> &tensors, const Device *destination, - int64_t dim) { +std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) { std::vector> outputs; for (const auto &tensor : tensors) { outputs.push_back(std::make_shared(tensor->To(destination))); } auto kernel = Dispatcher::Instance().GetKernel({tensors[0]->GetDevice()->Type(), "StackForward"}); diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 35fe3830..bf575e11 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -251,7 +251,7 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { break; \ } -std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, const Device *device) { +std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Device device) { int64_t num_elements = end - start; auto tensor = std::make_shared(std::vector{num_elements}, dtype, device); device->SetDevice(); diff --git a/infini_train/src/nn/modules/linear.cc b/infini_train/src/nn/modules/linear.cc index b6822cc1..e5a58d01 100644 --- a/infini_train/src/nn/modules/linear.cc +++ b/infini_train/src/nn/modules/linear.cc @@ -10,7 +10,7 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -Linear::Linear(int64_t in_features, int64_t out_features, bool bias, const Device *device) +Linear::Linear(int64_t in_features, int64_t out_features, bool bias, Device device) : CloneableModule(kType), bias_(bias) { device_ = device ? device : Device(); diff --git a/infini_train/src/nn/modules/normalization.cc b/infini_train/src/nn/modules/normalization.cc index 9c3273fc..df1d4afb 100644 --- a/infini_train/src/nn/modules/normalization.cc +++ b/infini_train/src/nn/modules/normalization.cc @@ -9,7 +9,7 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, const Device *device) +LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, Device device) : CloneableModule(kType), eps_(eps) { device_ = device ? 
device : Device(); diff --git a/infini_train/src/nn/modules/sparse.cc b/infini_train/src/nn/modules/sparse.cc index 2fdeafb8..9314fe6d 100644 --- a/infini_train/src/nn/modules/sparse.cc +++ b/infini_train/src/nn/modules/sparse.cc @@ -10,9 +10,8 @@ namespace infini_train::nn { -Embedding::Embedding(int num_embeddings, int embedding_dim, const Device *device) : CloneableModule(kType) { - device_ = device ? device : Device(); - +Embedding::Embedding(int num_embeddings, int embedding_dim, Device device) : CloneableModule(kType) { + device_ = device; parameters_[kParamWeightName] = std::make_shared(std::vector{num_embeddings, embedding_dim}, DataType::kFLOAT32, device_) ->RequiresGrad(); diff --git a/infini_train/src/nn/parallel/data_parallel.cc b/infini_train/src/nn/parallel/data_parallel.cc index b68836a4..d7899e44 100644 --- a/infini_train/src/nn/parallel/data_parallel.cc +++ b/infini_train/src/nn/parallel/data_parallel.cc @@ -20,8 +20,7 @@ constexpr char kModuleName[] = "module"; std::vector>> ParallelApply(const std::vector> &modules, - const std::vector>> &inputs, - const std::vector &devices) { + const std::vector>> &inputs, const std::vector &devices) { CHECK_EQ(modules.size(), inputs.size()) << std::format( "The number of modules {} is not equal to the number of inputs {}", modules.size(), inputs.size()); CHECK_EQ(modules.size(), devices.size()); @@ -30,7 +29,7 @@ ParallelApply(const std::vector> &modules, std::vector>>> results(modules.size(), std::nullopt); auto worker = [&](const std::shared_ptr &module, const std::vector> &inputs, - const Device *device, int idx) { + Device device, int idx) { device->SetDevice(); auto output = (*module)(inputs); results[idx] = output; diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 50408949..595b597f 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -41,7 +41,7 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const } std::vector>> Scatter(const std::vector> &input_tensors, - const std::vector &devices, int dim) { + const std::vector &devices, int dim) { std::vector>> output_tensors; for (const auto &tensor : input_tensors) { output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); @@ -56,15 +56,14 @@ std::vector>> Scatter(const std::vector> Gather(const std::vector>> &tensors, - const Device *target_device, int dim) { + Device target_device, int dim) { std::vector> gather_tensors; for (const auto &tensor : tensors) { gather_tensors.push_back(tensor[0]); } return std::make_shared(target_device, dim)->Apply(gather_tensors); } std::vector>> -BroadcastCoalescedReshape(const std::vector> &tensors, - const std::vector &devices) { +BroadcastCoalescedReshape(const std::vector> &tensors, const std::vector &devices) { if (tensors.empty()) { return {}; } @@ -80,7 +79,7 @@ BroadcastCoalescedReshape(const std::vector> &tensors, } std::vector> Replicate(const std::shared_ptr &network, - const std::vector &devices) { + const std::vector &devices) { const int num_replicas = devices.size(); // FIXME(dcj): Parameters function need deduplication diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index 6f02662c..25b5be23 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -34,7 +34,7 @@ int PipelineStage::prev_rank() const { return prev_rank_; } int 
PipelineStage::next_rank() const { return next_rank_; } int PipelineStage::num_stages() const { return num_stages_; } -const Device *PipelineStage::device() const { return device_; } +Device PipelineStage::device() const { return device_; } const std::vector> &PipelineStage::recv_shape() const { return recv_shape_; } const std::vector> &PipelineStage::chunks() { return chunks_; } std::vector> *PipelineStage::mutable_chunks() { return &chunks_; } diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index bac71f0b..afcdaac2 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -18,8 +18,7 @@ class ISend : public autograd::Function { public: static constexpr char kType[] = "ISendFunction"; - explicit ISend(const Device *target_device, int cur_rank, int peer_rank, - const std::vector> &shape) + explicit ISend(Device target_device, int cur_rank, int peer_rank, const std::vector> &shape) : autograd::Function(kType), target_device_(target_device), cur_rank_(cur_rank), peer_rank_(peer_rank), shapes_(shape) {} @@ -28,8 +27,8 @@ class ISend : public autograd::Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: - const Device *target_device_ = nullptr; - const Device *input_device_ = nullptr; + Device target_device_ = Device(); + Device input_device_ = Device(); int cur_rank_ = -1; int peer_rank_ = -1; const std::vector> &shapes_; @@ -39,7 +38,7 @@ class IRecv : public autograd::Function { public: static constexpr char kType[] = "IRecvFunction"; - explicit IRecv(const Device *src_device, int cur_rank, int peer_rank) + explicit IRecv(Device src_device, int cur_rank, int peer_rank) : autograd::Function(kType), src_device_(src_device), cur_rank_(cur_rank), peer_rank_(peer_rank) {} std::vector> Forward(const std::vector> &input_tensors) override; @@ -50,8 +49,8 @@ class IRecv : public autograd::Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: - const Device *src_device_ = nullptr; - const Device *cur_device_ = nullptr; + Device src_device_ = Device(); + Device cur_device_ = Device(); int cur_rank_ = -1; int peer_rank_ = -1; }; @@ -112,14 +111,14 @@ std::vector> IRecv::Backward(const std::vector> ISend(const std::vector> &input_tensors, - const Device *target_device, int cur_rank, int peer_rank, + Device target_device, int cur_rank, int peer_rank, const std::vector> &shape) { auto func = std::make_shared(target_device, cur_rank, peer_rank, shape); return func->Apply(input_tensors); } -std::vector> IRecv(const std::vector> &outputs, - const Device *src_device, int cur_rank, int peer_rank) { +std::vector> IRecv(const std::vector> &outputs, Device src_device, + int cur_rank, int peer_rank) { auto func = std::make_shared(src_device, cur_rank, peer_rank); return func->Apply(outputs); } diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 45a3eac7..b4f5eef0 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -380,7 +380,7 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te std::vector> outputs; std::vector streams; std::vector comms; - std::vector devices; + std::vector devices; CHECK_EQ(world_size_, comms_.size()); @@ -423,12 +423,12 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te std::vector> ProcessGroupNCCL::ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) const { 
+ Device destination) const { // grads: [devices, tensors] std::vector> outputs; std::vector streams; std::vector comms; - std::vector devices; + std::vector devices; for (size_t i = 0; i < grads[0].size(); ++i) { outputs.push_back(std::make_shared(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination)); @@ -468,7 +468,7 @@ ProcessGroupNCCL::ReduceAddCoalesced(const std::vector> ProcessGroupNCCL::Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const { + std::vector devices, int64_t dim) const { std::vector> outputs; std::vector> split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim); std::vector streams; @@ -503,7 +503,7 @@ std::vector> ProcessGroupNCCL::Scatter(const std::shared } std::shared_ptr ProcessGroupNCCL::Gather(const std::vector> &tensors, - const Device *destination, int64_t dim) const { + Device destination, int64_t dim) const { std::vector> outouts; int64_t num_devices = tensors.size(); auto dtype = tensors[0]->Dtype(); @@ -513,7 +513,7 @@ std::shared_ptr ProcessGroupNCCL::Gather(const std::vector streams; std::vector comms; - std::vector devices; + std::vector devices; int dest_rank = -1; for (size_t i = 0; i < tensors.size(); ++i) { diff --git a/infini_train/src/nn/parallel/work.cc b/infini_train/src/nn/parallel/work.cc index 53fd465a..00ff18c6 100644 --- a/infini_train/src/nn/parallel/work.cc +++ b/infini_train/src/nn/parallel/work.cc @@ -15,7 +15,7 @@ std::exception_ptr makeCudaError(cudaError_t err) { } } // namespace -WorkNccl::WorkNccl(const Device *device, ncclComm_t comm) : device_(device), comm_(comm) { +WorkNccl::WorkNccl(Device device, ncclComm_t comm) : device_(device), comm_(comm) { CUDA_CHECK(cudaEventCreateWithFlags(&ready_event_, cudaEventDisableTiming)); CUDA_CHECK(cudaEventCreateWithFlags(&done_event_, cudaEventDisableTiming)); } From 154cdcaa03b8224efee5aee3ba49d85e88759a1d Mon Sep 17 00:00:00 2001 From: kilinchange Date: Sat, 17 Jan 2026 07:10:22 +0000 Subject: [PATCH 4/7] refactor: replace device->Type()/Index() with device.type()/index() --- example/gpt2/main.cc | 2 +- infini_train/src/autograd/accumulate.cc | 2 +- infini_train/src/autograd/activations.cc | 4 +- infini_train/src/autograd/comm.cc | 8 +- infini_train/src/autograd/elementwise.cc | 84 +++++++++---------- infini_train/src/autograd/function.cc | 6 +- infini_train/src/autograd/linear.cc | 4 +- infini_train/src/autograd/loss.cc | 4 +- infini_train/src/autograd/matmul.cc | 4 +- infini_train/src/autograd/misc.cc | 24 +++--- infini_train/src/autograd/normalization.cc | 4 +- infini_train/src/autograd/outer.cc | 4 +- infini_train/src/autograd/reduction.cc | 16 ++-- infini_train/src/autograd/softmax.cc | 4 +- infini_train/src/autograd/sparse.cc | 4 +- infini_train/src/autograd/transform.cc | 20 ++--- .../src/kernels/cuda/accumulate_grad.cu | 2 +- infini_train/src/kernels/cuda/comm.cu | 6 +- infini_train/src/kernels/cuda/gather.cu | 4 +- infini_train/src/nn/init.cc | 18 ++-- infini_train/src/nn/parallel/ddp/reducer.cc | 6 +- .../src/nn/parallel/parallel_functional.cc | 6 +- .../src/nn/parallel/pp/pipeline_schedule.cc | 4 +- .../src/nn/parallel/tensor_parallel.cc | 3 +- infini_train/src/optimizer.cc | 4 +- infini_train/src/tensor.cc | 2 +- 26 files changed, 124 insertions(+), 125 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index a0fac4bf..8df0df96 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -323,7 +323,7 @@ void Train(const nn::parallel::Rank &rank) { for (int micro_step = 0; micro_step < grad_accum_steps; 
++micro_step) { // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + infini_train::AutocastGuard autocast_guard(device.type(), dtype); // (bs, seq_len), (bs, seq_len) auto [x, y] = *train_iter; diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index def9cad8..9e1ac184 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -32,7 +32,7 @@ AccumulateGrad::Backward(const std::vector> &grad_output // NOTE(zbl): must copy, cannot change grad buffer address grad->CopyFrom(grad_output); } else { - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(grad_output, learning_rate_, grad); } } else { diff --git a/infini_train/src/autograd/activations.cc b/infini_train/src/autograd/activations.cc index 1706082b..3641865a 100644 --- a/infini_train/src/autograd/activations.cc +++ b/infini_train/src/autograd/activations.cc @@ -10,7 +10,7 @@ std::vector> Sigmoid::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SigmoidForward"}, input)}; } @@ -26,7 +26,7 @@ std::vector> Sigmoid::Backward(const std::vectorGetDevice()->Type(); + auto device = output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SigmoidBackward"}, output, grad_output)}; } } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index 712d8ea7..0e0028d0 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -18,7 +18,7 @@ Scatter::Scatter(const std::vector &target_gpus, int64_t dim, std::vector> Scatter::Forward(const std::vector> &input_tensors) { const auto &input = input_tensors[0]; std::vector> output_tensors; - auto device = input->GetDevice()->Type(); + auto device = input->GetDevice().type(); output_tensors = pg_->Scatter(input, target_gpus_, dim_); return output_tensors; } @@ -38,7 +38,7 @@ Gather::Gather(Device target_device, int64_t dim, const infini_train::nn::parall std::vector> Gather::Forward(const std::vector> &input_tensors) { for (const auto &tensor : input_tensors) { - CHECK_NE(static_cast(tensor->GetDevice()->Type()), static_cast(Device::DeviceType::kCPU)) + CHECK_NE(static_cast(tensor->GetDevice().type()), static_cast(Device::DeviceType::kCPU)) << "Gather function not implemented for CPU tensors"; } if (dim_ == 0 && input_tensors[0]->Dims().size() == 0) { @@ -51,7 +51,7 @@ std::vector> Gather::Forward(const std::vectorGetDevice()->Type(); + auto device = input_tensors[0]->GetDevice().type(); return {pg_->Gather(input_tensors, target_device_, dim_)}; } @@ -78,7 +78,7 @@ std::vector> Broadcast::Forward(const std::vectorGetDevice()->IsCPU()) << "Broadcast function not implemented for CPU tensors"; - CHECK(tensor->GetDevice()->Type() == input_device_->Type()) + CHECK(tensor->GetDevice().type() == input_device_.type()) << "Broadcast function not implemented for tensors on different device type"; } diff --git a/infini_train/src/autograd/elementwise.cc b/infini_train/src/autograd/elementwise.cc index a00536a3..7291e284 100644 --- a/infini_train/src/autograd/elementwise.cc +++ b/infini_train/src/autograd/elementwise.cc @@ -10,7 +10,7 @@ std::vector> Neg::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return 
{Dispatcher::Instance().Call>({device, "NegForward"}, input)}; } @@ -18,7 +18,7 @@ std::vector> Neg::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NegBackward"}, grad_output)}; } @@ -26,7 +26,7 @@ std::vector> Reciprocal::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ReciprocalForward"}, input)}; } @@ -42,7 +42,7 @@ std::vector> Reciprocal::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ReciprocalBackward"}, grad_output, input)}; } @@ -50,7 +50,7 @@ std::vector> Sin::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SinForward"}, input)}; } @@ -66,7 +66,7 @@ std::vector> Sin::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SinBackward"}, grad_output, input)}; } @@ -74,7 +74,7 @@ std::vector> Cos::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "CosForward"}, input)}; } @@ -90,7 +90,7 @@ std::vector> Cos::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "CosBackward"}, grad_output, input)}; } @@ -98,7 +98,7 @@ std::vector> Tanh::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TanhForward"}, input)}; } @@ -114,7 +114,7 @@ std::vector> Tanh::Backward(const std::vectorGetDevice()->Type(); + auto device = output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TanhBackward"}, grad_output, output)}; } @@ -122,7 +122,7 @@ std::vector> Pow::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "PowForward"}, input, exponent_, scalar_is_base_)}; } @@ -139,7 +139,7 @@ std::vector> Pow::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "PowBackward"}, grad_output, input, exponent_, scalar_is_base_)}; } @@ -148,7 +148,7 @@ std::vector> Rsqrt::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RsqrtForward"}, input)}; } @@ -164,7 +164,7 @@ std::vector> Rsqrt::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RsqrtBackward"}, grad_output, input)}; } @@ -172,7 +172,7 @@ std::vector> Exp::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ExpForward"}, input)}; } @@ -180,7 +180,7 @@ std::vector> Exp::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ExpBackward"}, grad_output)}; } @@ -188,7 +188,7 @@ std::vector> Log::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LogForward"}, input)}; } @@ -204,7 +204,7 @@ std::vector> Log::Backward(const std::vectorGetDevice()->Type(); + auto device = 
input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LogBackward"}, grad_output, input)}; } @@ -213,7 +213,7 @@ std::vector> Equals::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "EqualsForward"}, input, other)}; } @@ -226,7 +226,7 @@ std::vector> EqualsScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "EqualsScalarForward"}, input, scalar_)}; } @@ -240,7 +240,7 @@ std::vector> Lt::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LtForward"}, a, b)}; } @@ -253,7 +253,7 @@ std::vector> LtScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LtScalarForward"}, input, scalar_)}; } @@ -267,7 +267,7 @@ std::vector> Le::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LeForward"}, a, b)}; } @@ -280,7 +280,7 @@ std::vector> LeScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LeScalarForward"}, input, scalar_)}; } @@ -294,7 +294,7 @@ std::vector> Gt::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GtForward"}, a, b)}; } @@ -307,7 +307,7 @@ std::vector> GtScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GtScalarForward"}, input, scalar_)}; } @@ -321,7 +321,7 @@ std::vector> Ge::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GeForward"}, a, b)}; } @@ -334,7 +334,7 @@ std::vector> GeScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GeScalarForward"}, input, scalar_)}; } @@ -348,7 +348,7 @@ std::vector> Or::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "OrForward"}, a, b)}; } @@ -362,7 +362,7 @@ std::vector> And::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AndForward"}, a, b)}; } @@ -376,7 +376,7 @@ std::vector> Add::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AddForward"}, a, b)}; } @@ -390,7 +390,7 @@ std::vector> Add::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "AddBackward"}, grad_output, a_dims_, b_dims_); return {grad_a, grad_b}; @@ -400,7 +400,7 @@ std::vector> AddScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AddScalarForward"}, input, scalar_)}; } @@ -408,7 +408,7 @@ std::vector> AddScalar::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AddScalarBackward"}, grad_output)}; } @@ -417,7 +417,7 @@ std::vector> 
Sub::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SubForward"}, a, b)}; } @@ -431,7 +431,7 @@ std::vector> Sub::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "SubBackward"}, grad_output, a_dims_, b_dims_); return {grad_a, grad_b}; @@ -442,7 +442,7 @@ std::vector> Mul::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MulForward"}, a, b)}; } @@ -460,7 +460,7 @@ std::vector> Mul::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "MulBackward"}, grad_output, a, b); return {grad_a, grad_b}; @@ -470,7 +470,7 @@ std::vector> MulScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MulScalarForward"}, input, scalar_)}; } @@ -478,7 +478,7 @@ std::vector> MulScalar::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MulScalarBackward"}, grad_output, scalar_)}; } @@ -487,7 +487,7 @@ std::vector> Div::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "DivForward"}, a, b)}; } @@ -505,7 +505,7 @@ std::vector> Div::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "DivBackward"}, grad_output, a, b); return {grad_a, grad_b}; diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 7f223b26..345f4d13 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -17,7 +17,7 @@ namespace infini_train::autograd { std::vector> Function::Apply(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 1); - const auto *device = input_tensors[0]->GetDevice(); + auto device = input_tensors[0]->GetDevice(); // TODO(dcj): Cache context information to reduce setDevice overhead. device->SetDevice(); @@ -88,7 +88,7 @@ std::vector> Function::Apply(const std::vector &grad_output, int grad_output_idx) { - const auto *device = grad_output->GetDevice(); + auto device = grad_output->GetDevice(); device->SetDevice(); // NOTE(dcj): The accumulate autograd function has no grad_outputs. 
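The hunks in this file follow the same mechanical rewrite applied throughout PATCH 4/7: kernels are still looked up through Dispatcher::Instance() keyed on the device type plus a kernel name, but the key is now read from the value-type Device via .type() rather than through the old const Device* with ->Type(). Below is a minimal sketch of the resulting call-site shape, assuming Tensor::GetDevice() returns a Device by value as in the hunk above; the free-function wrapper and the dispatcher header path are illustrative and not part of this patch.

    #include <memory>

    #include "infini_train/include/device.h"
    #include "infini_train/include/dispatcher.h" // assumed location of Dispatcher
    #include "infini_train/include/tensor.h"

    namespace infini_train {

    // Accumulate grad_output into grad through the kernel registered for the
    // tensor's device type; mirrors the AccumulateGrad call sites in this file.
    void AccumulateInto(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr<Tensor> &grad) {
        auto device = grad_output->GetDevice(); // Device held by value after the refactor
        device->SetDevice();                    // pointer-like access kept, as in Function::Apply above
        auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"});
        kernel.Call(grad_output, /*learning_rate=*/1.0f, grad);
    }

    } // namespace infini_train
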
@@ -100,7 +100,7 @@ void Function::BackwardPartial(const std::shared_ptr &grad_output, int g grad_outputs_[grad_output_idx] = grad_output; ++grad_outputs_reached_; } else { - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(grad_output, 1.0f, grad_outputs_.at(grad_output_idx)); } ++dependencies_reached_; diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index 53330211..be397c32 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -12,7 +12,7 @@ std::vector> Linear::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LinearForward"}, input, weight, true, bias)}; } @@ -32,7 +32,7 @@ std::vector> Linear::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto [grad_input, grad_weight, grad_bias] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( diff --git a/infini_train/src/autograd/loss.cc b/infini_train/src/autograd/loss.cc index 26e9957d..657ea649 100644 --- a/infini_train/src/autograd/loss.cc +++ b/infini_train/src/autograd/loss.cc @@ -11,7 +11,7 @@ std::vector> CrossEntropy::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "CrossEntropyForward"}, input, target)}; } @@ -29,7 +29,7 @@ std::vector> CrossEntropy::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto grad_input = Dispatcher::Instance().Call>({device, "CrossEntropyBackward"}, input, target, grad_output); return {grad_input, nullptr}; diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 68136ba0..335396d6 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -11,7 +11,7 @@ std::vector> Matmul::Forward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MatmulForward"}, input1, input2)}; } @@ -31,7 +31,7 @@ std::vector> Matmul::Backward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); auto [grad_input1, grad_input2] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "MatmulBackward"}, input1, input2, grad_output); diff --git a/infini_train/src/autograd/misc.cc b/infini_train/src/autograd/misc.cc index cdfba331..601258eb 100644 --- a/infini_train/src/autograd/misc.cc +++ b/infini_train/src/autograd/misc.cc @@ -10,7 +10,7 @@ std::vector> Split::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, split_size_, dim_)}; } @@ -23,7 +23,7 @@ void Split::SetupContext(const std::vector> &input_tenso std::vector> Split::Backward(const std::vector> &grad_outputs) { auto device = grad_outputs[0]->GetDevice(); - return {Dispatcher::Instance().Call>({device->Type(), "SplitBackward"}, input_dims_, + return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, split_size_, dim_, grad_outputs)}; } @@ -32,7 +32,7 @@ std::vector> IndexGather::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto kernel = Dispatcher::Instance().GetKernel({device, "IndexGatherForward"}); return {kernel.Call>(input, index, dim_)}; } @@ 
-51,7 +51,7 @@ std::vector> IndexGather::Backward(const std::vectorGetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "IndexGatherBackward"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "IndexGatherBackward"}); return {kernel.Call>(grad_output, index, dim_, input_dims_)}; } @@ -59,7 +59,7 @@ std::vector> NoOp::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; } @@ -73,7 +73,7 @@ std::vector> NoOp::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; } @@ -81,7 +81,7 @@ std::vector> Slice::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return { Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; } @@ -98,14 +98,14 @@ std::vector> Slice::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, ends_, steps_)}; } std::vector> Stack::Forward(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice()->Type(); + const auto device = input_tensors[0]->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; } @@ -119,14 +119,14 @@ void Stack::SetupContext(const std::vector> &input_tenso std::vector> Stack::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, dim_, grad_output)}; } std::vector> Concat::Forward(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice()->Type(); + const auto device = input_tensors[0]->GetDevice().type(); auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); return {kernel.Call>(input_tensors, dim_)}; @@ -140,7 +140,7 @@ void Concat::SetupContext(const std::vector> &input_tens std::vector> Concat::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); return kernel.Call>>(grad_output, input_dims_list_, dim_); } diff --git a/infini_train/src/autograd/normalization.cc b/infini_train/src/autograd/normalization.cc index 58d3bdc5..79a14abb 100644 --- a/infini_train/src/autograd/normalization.cc +++ b/infini_train/src/autograd/normalization.cc @@ -13,7 +13,7 @@ std::vector> LayerNorm::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto [output, mean, rstd] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( @@ -40,7 +40,7 @@ std::vector> LayerNorm::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto [grad_input, grad_weight, grad_bias] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( diff --git a/infini_train/src/autograd/outer.cc b/infini_train/src/autograd/outer.cc index 347df100..85a8c9ca 100644 --- a/infini_train/src/autograd/outer.cc +++ 
b/infini_train/src/autograd/outer.cc @@ -14,7 +14,7 @@ std::vector> Outer::Forward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "OuterForward"}, input1, input2)}; } @@ -32,7 +32,7 @@ std::vector> Outer::Backward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); auto [grad_input1, grad_input2] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "OuterBackward"}, input1, input2, grad_output); diff --git a/infini_train/src/autograd/reduction.cc b/infini_train/src/autograd/reduction.cc index e5244947..5a6e086f 100644 --- a/infini_train/src/autograd/reduction.cc +++ b/infini_train/src/autograd/reduction.cc @@ -13,7 +13,7 @@ std::vector> Mean::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MeanForward"}, input, dim_, keep_dim_)}; } @@ -27,7 +27,7 @@ std::vector> Mean::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MeanBackward"}, grad_output, input_dims_, dim_, keep_dim_)}; } @@ -36,7 +36,7 @@ std::vector> Sum::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SumForward"}, input, dim_, keep_dim_)}; } @@ -50,7 +50,7 @@ std::vector> Sum::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SumBackward"}, grad_output, input_dims_, dim_, keep_dim_)}; } @@ -59,7 +59,7 @@ std::vector> Max::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaxForward"}, input, dim_, keep_dim_)}; } @@ -77,7 +77,7 @@ std::vector> Max::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaxBackward"}, grad_output, input, reduced, dim_, keep_dim_)}; } @@ -86,7 +86,7 @@ std::vector> Min::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MinForward"}, input, dim_, keep_dim_)}; } @@ -104,7 +104,7 @@ std::vector> Min::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MinBackward"}, grad_output, input, reduced, dim_, keep_dim_)}; } diff --git a/infini_train/src/autograd/softmax.cc b/infini_train/src/autograd/softmax.cc index 1987b6f7..39569a8c 100644 --- a/infini_train/src/autograd/softmax.cc +++ b/infini_train/src/autograd/softmax.cc @@ -10,7 +10,7 @@ std::vector> Softmax::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SoftmaxForward"}, input, dim_)}; } @@ -26,7 +26,7 @@ std::vector> Softmax::Backward(const std::vectorGetDevice()->Type(); + auto device = output->GetDevice().type(); return { Dispatcher::Instance().Call>({device, "SoftmaxBackward"}, grad_output, output, dim_)}; } diff --git a/infini_train/src/autograd/sparse.cc b/infini_train/src/autograd/sparse.cc index 19867d55..93315b4f 100644 --- a/infini_train/src/autograd/sparse.cc +++ b/infini_train/src/autograd/sparse.cc @@ -11,7 +11,7 @@ std::vector> Embedding::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return 
{Dispatcher::Instance().Call>({device, "EmbeddingForward"}, input, weight)}; } @@ -28,7 +28,7 @@ std::vector> Embedding::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto grad_weight = Dispatcher::Instance().Call>({device, "EmbeddingBackward"}, input, weight_dims_, grad_output); return {nullptr, grad_weight}; diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index 3c33fea3..4fae05bb 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -8,14 +8,14 @@ std::vector> Tril::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TrilForward"}, input, diagonal_)}; } std::vector> Tril::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TrilBackward"}, grad_output, diagonal_)}; } @@ -23,14 +23,14 @@ std::vector> Triu::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TriuForward"}, input, diagonal_)}; } std::vector> Triu::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TriuBackward"}, grad_output, diagonal_)}; } @@ -38,14 +38,14 @@ std::vector> Transpose::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TransposeForward"}, input, dim0_, dim1_)}; } std::vector> Transpose::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return { Dispatcher::Instance().Call>({device, "TransposeBackward"}, grad_output, dim0_, dim1_)}; } @@ -54,14 +54,14 @@ std::vector> Mask::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaskForward"}, input, mask_, value_)}; } std::vector> Mask::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaskBackward"}, grad_output, mask_)}; } @@ -70,7 +70,7 @@ RepeatInterleave::Forward(const std::vector> &input_tens CHECK_EQ(input_tensors.size(), 1); const auto &input = input_tensors[0]; - auto device = input->GetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RepeatInterleaveForward"}, input, repeat_, dim_)}; } @@ -85,7 +85,7 @@ std::vector> RepeatInterleave::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RepeatInterleaveBackward"}, grad_output, input_dims_, dim_)}; } diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index c922cb35..5ff36a2d 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ 
b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -24,7 +24,7 @@ void AccumulateGrad(const std::shared_ptr &gradient, float rate, const s const auto *device = tensor->GetDevice(); - auto device_impl = GetDeviceGuardImpl(device->Type()); + auto device_impl = GetDeviceGuardImpl(device.type()); DispatchFunc( gradient->Dtype(), diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index 9c22d9d9..c3063e99 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -25,7 +25,7 @@ std::vector> Broadcast(const std::vector> ReduceAddCoalesced(const std::vector>> &grads, Device destination) { std::vector> outputs; - auto kernel = Dispatcher::Instance().GetKernel({destination->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({destination.type(), "AccumulateGrad"}); std::vector>> to_destination_grads; for (int i = 0; i < grads[0].size(); ++i) { outputs.emplace_back(std::make_shared(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination)); @@ -59,12 +59,12 @@ std::vector> Scatter(const std::shared_ptr &tens std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) { std::vector> outputs; for (const auto &tensor : tensors) { outputs.push_back(std::make_shared(tensor->To(destination))); } - auto kernel = Dispatcher::Instance().GetKernel({tensors[0]->GetDevice()->Type(), "StackForward"}); + auto kernel = Dispatcher::Instance().GetKernel({tensors[0]->GetDevice().type(), "StackForward"}); auto gathered_tensor = kernel.Call>(outputs, dim); auto old_dims = gathered_tensor->Dims(); std::vector new_dims{old_dims[0] * old_dims[1]}; for (int i = 2; i < old_dims.size(); ++i) { new_dims.push_back(old_dims[i]); } - auto view_kernel = Dispatcher::Instance().GetKernel({destination->Type(), "NoOpForward"}); + auto view_kernel = Dispatcher::Instance().GetKernel({destination.type(), "NoOpForward"}); return view_kernel.Call>(gathered_tensor, new_dims); } } // namespace infini_train::kernels::cuda diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 47d63478..a111721b 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -44,8 +44,8 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); - CHECK(input->GetDevice()->Type() == index->GetDevice()->Type()); - CHECK(input->GetDevice()->Index() == index->GetDevice()->Index()); + CHECK(input->GetDevice().type() == index->GetDevice().type()); + CHECK(input->GetDevice().index() == index->GetDevice().index()); const int64_t num_dims = in_dims.size(); if (dim < 0) { diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index bf575e11..4b704d5d 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -48,7 +48,7 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean auto device = tensor->GetDevice(); device->SetDevice(); - switch (device->Type()) { + switch (device.type()) { case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; @@ -62,7 +62,7 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean } #endif default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); break; } } @@ -154,7 
+154,7 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, auto device = tensor->GetDevice(); device->SetDevice(); - switch (device->Type()) { + switch (device.type()) { case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; @@ -168,7 +168,7 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, } #endif default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); break; } } @@ -184,7 +184,7 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { auto device = tensor->GetDevice(); device->SetDevice(); - switch (device->Type()) { + switch (device.type()) { case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; @@ -198,7 +198,7 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { } #endif default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); break; } } @@ -214,7 +214,7 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { auto device = tensor->GetDevice(); device->SetDevice(); - switch (device->Type()) { + switch (device.type()) { case Device::DeviceType::kCPU: { memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); break; @@ -228,7 +228,7 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { } #endif default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); break; } } @@ -294,7 +294,7 @@ std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Devic break; } #else - LOG(FATAL) << "Unsupported device type: " << static_cast(device->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); #endif } return tensor; diff --git a/infini_train/src/nn/parallel/ddp/reducer.cc b/infini_train/src/nn/parallel/ddp/reducer.cc index 092fa74b..905362e2 100644 --- a/infini_train/src/nn/parallel/ddp/reducer.cc +++ b/infini_train/src/nn/parallel/ddp/reducer.cc @@ -26,7 +26,7 @@ void CopyGradToBucket(const std::shared_ptr &grad, const std::shared_ptr char *dst = static_cast(flat->DataPtr()) + dst_elem_offset * element_size_in_bytes; const void *src = grad->DataPtr(); - const auto dev_type = grad->GetDevice()->Type(); + const auto dev_type = grad->GetDevice().type(); if (dev_type == Device::DeviceType::kCPU) { std::memcpy(dst, src, bytes); return; @@ -52,7 +52,7 @@ void CopyBucketToGrad(const std::shared_ptr &flat, const std::shared_ptr const char *src = static_cast(flat->DataPtr()) + src_elem_offset * element_size_in_bytes; void *dst = grad->DataPtr(); - const auto dev_type = grad->GetDevice()->Type(); + const auto dev_type = grad->GetDevice().type(); if (dev_type == Device::DeviceType::kCPU) { std::memcpy(dst, src, bytes); return; @@ -135,7 +135,7 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector const auto &tensor = tensors[idx_in_order]; CHECK(tensor); - const Key k = Key{tensors[idx_in_order]->GetDevice()->Index(), tensors[idx_in_order]->Dtype()}; + const Key k = Key{tensors[idx_in_order]->GetDevice().index(), tensors[idx_in_order]->Dtype()}; auto it = states.find(k); if (it == states.end()) { it = states.emplace(k, State{}).first; diff --git a/infini_train/src/nn/parallel/parallel_functional.cc 
b/infini_train/src/nn/parallel/parallel_functional.cc index 595b597f..719365b4 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -15,7 +15,7 @@ namespace infini_train::nn::parallel::function { std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { - auto device = tensor->GetDevice()->Type(); + auto device = tensor->GetDevice().type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } @@ -24,7 +24,7 @@ std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpT std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, const ProcessGroup *pg, bool async_op) { - auto device = output->GetDevice()->Type(); + auto device = output->GetDevice().type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } @@ -33,7 +33,7 @@ std::shared_ptr AllGather(const std::shared_ptr &output, const std std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { - auto device = output->GetDevice()->Type(); + auto device = output->GetDevice().type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 7496017b..1df2f1b5 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -213,7 +213,7 @@ float PipelineSchedule::StepMicroBatches(const std::vectordevice()->Type(), dtype); + infini_train::AutocastGuard autocast_guard(stage_->device().type(), dtype); std::vector> inputs; @@ -241,7 +241,7 @@ float PipelineSchedule::StepMicroBatches(const std::vector loss; { - infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + infini_train::AutocastGuard autocast_guard(stage_->device().type(), dtype); auto target_on_device = target->To(activations[task.local_chunk_idx][mb][0]->GetDevice()); loss = (*loss_fn)( diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 129e2f4b..611ef0ce 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -15,7 +15,6 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/utils.h" -#include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::parallel { @@ -534,7 +533,7 @@ VocabParallelCrossEntropy::Backward(const std::vector> & auto masked_target = saved_tensors_[2]; auto valid_mask_local = saved_tensors_[3]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto grad_input = Dispatcher::Instance().Call>( {device, "VocabParallelCrossEntropyBackward"}, grad_output, softmax_local, target_mask, masked_target, valid_mask_local, vocab_size_local_, vocab_size_original_, label_smoothing_); diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 80e3887f..afaf3c77 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -26,7 +26,7 @@ void SGD::Step() { } auto device = param->GetDevice(); device->SetDevice(); - auto kernel = 
Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(param->grad(), -learning_rate_, param); } } @@ -62,7 +62,7 @@ void Adam::Step() { auto device = param->GetDevice(); device->SetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AdamAccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AdamAccumulateGrad"}); kernel.Call(grad, param, m, v, learning_rate_, beta1_, beta2_, eps_, t_); } } diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index cb351d03..0d961fe5 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -290,7 +290,7 @@ void Tensor::CopyFrom(const Tensor &src) { } #endif default: - LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev->Type()); + LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev.type()); } break; } From 9a39990f38d768bb3d3302b585dd24058dde55ca Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 6 Feb 2026 05:11:54 +0000 Subject: [PATCH 5/7] refactor: replace old Device runtime calls with DeviceGuard/impl operations --- CMakeLists.txt | 207 ++++++++++---- example/gpt2/main.cc | 11 +- example/gpt2/net.cc | 4 +- example/llama3/main.cc | 23 +- infini_train/include/autocast.h | 13 +- infini_train/include/common/common.h | 3 + infini_train/include/core/device_guard.h | 48 +--- infini_train/include/dispatcher.h | 4 - infini_train/include/nn/modules/module.h | 2 +- .../include/nn/parallel/data_parallel.h | 2 +- .../ddp/distributed_data_parallel_config.h | 2 +- .../nn/parallel/pp/pipeline_parallel.h | 4 +- .../include/nn/parallel/process_group.h | 4 +- infini_train/include/profiler.h | 8 +- infini_train/include/tensor.h | 1 + infini_train/src/autograd/accumulate.cc | 3 +- infini_train/src/autograd/comm.cc | 2 +- infini_train/src/autograd/function.cc | 7 +- infini_train/src/core/cpu/cpu_guard.cc | 2 + .../src/core/cuda/cuda_blas_handle.cc | 2 + infini_train/src/core/cuda/cuda_blas_handle.h | 2 + infini_train/src/core/cuda/cuda_guard.cc | 4 +- infini_train/src/core/cuda/cuda_stream.cc | 2 +- infini_train/src/core/cuda/cuda_stream.h | 2 +- infini_train/src/core/device_guard.cc | 50 +--- .../src/kernels/cuda/accumulate_grad.cu | 19 +- infini_train/src/kernels/cuda/cast.cu | 9 +- infini_train/src/kernels/cuda/concat.cu | 16 +- .../src/kernels/cuda/cross_entropy.cu | 22 +- infini_train/src/kernels/cuda/elementwise.cu | 38 ++- infini_train/src/kernels/cuda/embedding.cu | 18 +- infini_train/src/kernels/cuda/fill.cu | 9 +- infini_train/src/kernels/cuda/gather.cu | 17 +- infini_train/src/kernels/cuda/layernorm.cu | 17 +- infini_train/src/kernels/cuda/linear.cu | 133 +++++---- infini_train/src/kernels/cuda/outer.cu | 14 +- infini_train/src/kernels/cuda/reduction.cu | 22 +- infini_train/src/kernels/cuda/slice.cu | 15 +- infini_train/src/kernels/cuda/softmax.cu | 18 +- infini_train/src/kernels/cuda/split.cu | 15 +- infini_train/src/kernels/cuda/stack.cu | 14 +- infini_train/src/kernels/cuda/transform.cu | 70 +++-- .../cuda/vocab_parallel_cross_entropy.cu | 17 +- infini_train/src/nn/init.cc | 112 +++----- infini_train/src/nn/modules/linear.cc | 2 +- infini_train/src/nn/modules/normalization.cc | 2 +- infini_train/src/nn/parallel/data_parallel.cc | 5 +- .../parallel/ddp/distributed_data_parallel.cc | 5 +- .../nn/parallel/ddp/param_and_grad_buffer.cc | 5 +- infini_train/src/nn/parallel/ddp/reducer.cc | 57 ++-- 
.../src/nn/parallel/pp/pipeline_parallel.cc | 8 +- infini_train/src/nn/parallel/pp/send_recv.cc | 9 +- infini_train/src/nn/parallel/process_group.cc | 111 +++++--- .../src/nn/parallel/tensor_parallel.cc | 21 +- infini_train/src/nn/parallel/work.cc | 11 +- infini_train/src/optimizer.cc | 5 +- infini_train/src/profiler.cc | 20 +- infini_train/src/tensor.cc | 269 ++++++------------ infini_train/src/utils/precision_checker.cc | 7 +- 59 files changed, 806 insertions(+), 738 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9ff66ecf..6c160686 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,11 @@ +cmake_minimum_required(VERSION 3.28) + option(USE_CUDA "Support NVIDIA CUDA" OFF) option(PROFILE_MODE "ENABLE PROFILE MODE" OFF) option(USE_OMP "Use OpenMP as backend for Eigen" ON) option(USE_NCCL "Build project for distributed running" ON) -cmake_minimum_required(VERSION 3.28) -project(infini_train VERSION 0.3.0 LANGUAGES CXX) +project(infini_train VERSION 0.5.0 LANGUAGES CXX) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -13,90 +14,186 @@ set(CMAKE_CXX_EXTENSIONS OFF) # Generate compile_commands.json set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -# Add gflags +# ------------------------------------------------------------------------------ +# Third-party deps +# ------------------------------------------------------------------------------ + +# gflags add_subdirectory(third_party/gflags) include_directories(${gflags_SOURCE_DIR}/include) +# glog set(WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE) set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE) - -# Add glog add_subdirectory(third_party/glog) include_directories(${glog_SOURCE_DIR}/src) -# Add eigen +# eigen if(USE_OMP) - find_package(OpenMP REQUIRED) + find_package(OpenMP REQUIRED) endif() -# find_package(OpenBLAS REQUIRED) -# include_directories(${OpenBLAS_INCLUDE_DIR}) add_subdirectory(third_party/eigen) include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen) -# add_definitions(-DEIGEN_USE_BLAS) include_directories(${PROJECT_SOURCE_DIR}) -file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc) -list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*") if(PROFILE_MODE) - add_compile_definitions(PROFILE_MODE=1) + add_compile_definitions(PROFILE_MODE=1) endif() -file (GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc) +# ------------------------------------------------------------------------------ +# Sources +# ------------------------------------------------------------------------------ + +# Framework core sources (*.cc), excluding cpu kernels (they are built separately) +file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc) +list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*") + +# CPU kernels (*.cc) +file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc) + +# ------------------------------------------------------------------------------ +# CPU kernels library +# ------------------------------------------------------------------------------ + add_library(infini_train_cpu_kernels STATIC ${CPU_KERNELS}) -target_link_libraries(infini_train_cpu_kernels glog Eigen3::Eigen) +target_link_libraries(infini_train_cpu_kernels PUBLIC glog Eigen3::Eigen) + if(USE_OMP) - add_compile_definitions(USE_OMP=1) - target_link_libraries(infini_train_cpu_kernels OpenMP::OpenMP_CXX) + add_compile_definitions(USE_OMP=1) + target_link_libraries(infini_train_cpu_kernels PUBLIC OpenMP::OpenMP_CXX) +endif() + +# 
------------------------------------------------------------------------------ +# CUDA kernels library (optional) +# ------------------------------------------------------------------------------ + +if(USE_CUDA) + add_compile_definitions(USE_CUDA=1) + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) + include_directories(${CUDAToolkit_INCLUDE_DIRS}) + + # CUDA compilation options + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") + + # Compile all CUDA sources (infini_train/src/*.cu) into a separate kernels library + file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu) + + add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS}) + set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90") + + target_link_libraries(infini_train_cuda_kernels + PUBLIC + glog + CUDA::cudart + CUDA::cublas + CUDA::cuda_driver + ) + + if(USE_NCCL) + message(STATUS "Add USE_NCCL, use NCCL with CUDA") + list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) + find_package(NCCL REQUIRED) + add_compile_definitions(USE_NCCL=1) + target_link_libraries(infini_train_cuda_kernels PUBLIC nccl) + endif() endif() +# ------------------------------------------------------------------------------ +# Main framework library +# ------------------------------------------------------------------------------ + +add_library(infini_train STATIC ${SRC}) +target_link_libraries(infini_train + PUBLIC + glog + gflags + infini_train_cpu_kernels +) + if(USE_CUDA) - add_compile_definitions(USE_CUDA=1) - enable_language(CUDA) - find_package(CUDAToolkit REQUIRED) - include_directories(${CUDAToolkit_INCLUDE_DIRS}) - - # enable CUDA-related compilation options - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") - file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu) - add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS}) - set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90") - target_link_libraries(infini_train_cuda_kernels glog CUDA::cudart CUDA::cublas CUDA::cuda_driver) - - add_library(infini_train STATIC ${SRC}) - target_link_libraries(infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels infini_train_cuda_kernels "-Wl,--no-whole-archive") - - if (USE_NCCL) - message(STATUS "Add USE_NCCL, use NCCL with CUDA") - list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) - find_package(NCCL REQUIRED) - add_compile_definitions(USE_NCCL=1) - target_link_libraries(infini_train nccl) - endif() -else() - add_library(infini_train STATIC ${SRC}) - target_link_libraries(infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels "-Wl,--no-whole-archive") + # infini_train contains cuda runtime wrappers (*.cc) like cuda_blas_handle.cc/cuda_guard.cc + # Those may need CUDA runtime/driver/cublas symbols at final link, so attach them here too. + target_link_libraries(infini_train + PUBLIC + infini_train_cuda_kernels + CUDA::cudart + CUDA::cublas + CUDA::cuda_driver + ) + + if(USE_NCCL) + # The core library may also reference NCCL symbols directly (not only through the kernels), + # so link NCCL against infini_train here as well. 
+ target_link_libraries(infini_train PUBLIC nccl) + endif() endif() +# ------------------------------------------------------------------------------ +# Helper: link libraries in a group to fix static lib one-pass resolution +# (THIS is what fixes "undefined reference" from cuda_kernels -> core symbols) +# ------------------------------------------------------------------------------ +function(link_infini_train_exe target_name) + if(USE_CUDA) + target_link_libraries(${target_name} PRIVATE + "-Wl,--start-group" + "-Wl,--whole-archive" + infini_train + infini_train_cpu_kernels + infini_train_cuda_kernels + "-Wl,--no-whole-archive" + "-Wl,--end-group" + ) + else() + target_link_libraries(${target_name} PRIVATE + "-Wl,--start-group" + "-Wl,--whole-archive" + infini_train + infini_train_cpu_kernels + "-Wl,--no-whole-archive" + "-Wl,--end-group" + ) + endif() +endfunction() + + +# ------------------------------------------------------------------------------ # Examples -add_executable(mnist example/mnist/main.cc example/mnist/dataset.cc example/mnist/net.cc) -target_link_libraries(mnist infini_train) +# ------------------------------------------------------------------------------ -add_executable(gpt2 example/gpt2/main.cc example/common/tiny_shakespeare_dataset.cc example/common/utils.cc example/gpt2/net.cc example/common/tokenizer.cc) -target_link_libraries(gpt2 infini_train) +add_executable(mnist + example/mnist/main.cc + example/mnist/dataset.cc + example/mnist/net.cc +) +link_infini_train_exe(mnist) + +add_executable(gpt2 + example/gpt2/main.cc + example/common/tiny_shakespeare_dataset.cc + example/common/utils.cc + example/gpt2/net.cc + example/common/tokenizer.cc +) +link_infini_train_exe(gpt2) + +add_executable(llama3 + example/llama3/main.cc + example/common/tiny_shakespeare_dataset.cc + example/common/utils.cc + example/llama3/net.cc + example/common/tokenizer.cc +) +link_infini_train_exe(llama3) -add_executable(llama3 example/llama3/main.cc example/common/tiny_shakespeare_dataset.cc example/common/utils.cc example/llama3/net.cc example/common/tokenizer.cc) -target_link_libraries(llama3 infini_train) +# Tools +add_subdirectory(tools/infini_run) +set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +# Tests add_executable(test_hook test/hook/test_hook.cc) target_link_libraries(test_hook infini_train) add_executable(test_precision_check test/hook/test_precision_check.cc) target_link_libraries(test_precision_check infini_train) - -add_subdirectory(tools/infini_run) - -set_target_properties(infini_run PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR} -) - diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 8df0df96..9d5d0313 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -10,6 +10,7 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/loss.h" @@ -274,7 +275,7 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; - auto cuda_device = device->IsCUDA() ? 
dynamic_cast(device) : nullptr; + auto impl = core::GetDeviceGuardImpl(device.type()); LOG(INFO) << "start training"; @@ -284,8 +285,8 @@ void Train(const nn::parallel::Rank &rank) { const bool last_step = step == FLAGS_num_iteration; - if (cuda_device) { - cuda_device->ResetMemPoolHighWatermarks(); + if (device.IsCUDA()) { + impl->ResetMemPoolHighWatermarks(device); } const auto iter_start = std::chrono::high_resolution_clock::now(); @@ -377,8 +378,8 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; - if (cuda_device) { - std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB(); + if (device.IsCUDA()) { + std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); } LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index 8df0bfe5..92b754df 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -199,8 +199,8 @@ GPT2FirstStage::Forward(const std::vector> int tp_rank = 0; if (tp_world_size > 1) { auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank())); - tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank()); + nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); } int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index d0bdcbd1..31354109 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -8,27 +8,28 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" +#include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/rank.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" -#include "infini_train/include/optimizer.h" -#ifdef PROFILE_MODE -#include "infini_train/include/profiler.h" -#endif -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/optimizer.h" #include "infini_train/include/utils/global_module_hook_registry.h" #include "infini_train/include/utils/precision_check_config.h" #include "infini_train/include/utils/precision_checker.h" +#ifdef PROFILE_MODE +#include "infini_train/include/profiler.h" +#endif #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" @@ -252,7 +253,7 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; - auto cuda_device = device->IsCUDA() ? 
dynamic_cast(device) : nullptr; + auto impl = core::GetDeviceGuardImpl(device.type()); for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { // Reset precision check counters at start of each iteration for file overwrite @@ -260,8 +261,8 @@ void Train(const nn::parallel::Rank &rank) { const bool last_step = step == FLAGS_num_iteration; - if (cuda_device) { - cuda_device->ResetMemPoolHighWatermarks(); + if (device.IsCUDA()) { + impl->ResetMemPoolHighWatermarks(device); } const auto iter_start = std::chrono::high_resolution_clock::now(); @@ -353,8 +354,8 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; - if (cuda_device) { - std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB(); + if (device.IsCUDA()) { + std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); } LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " diff --git a/infini_train/include/autocast.h b/infini_train/include/autocast.h index 3fb195c0..499c586f 100644 --- a/infini_train/include/autocast.h +++ b/infini_train/include/autocast.h @@ -3,15 +3,10 @@ #include #include -#include "common/common.h" -#include "datatype.h" -#include "device.h" -#include "tensor.h" - -#ifdef USE_CUDA -#include -#include -#endif +#include "infini_train/include/common/common.h" +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" +#include "infini_train/include/tensor.h" namespace infini_train { namespace { diff --git a/infini_train/include/common/common.h b/infini_train/include/common/common.h index 9d726d35..b6a02543 100644 --- a/infini_train/include/common/common.h +++ b/infini_train/include/common/common.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + #include "glog/logging.h" #include "infini_train/include/datatype.h" diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h index 8f4d62b2..1ff262d0 100644 --- a/infini_train/include/core/device_guard.h +++ b/infini_train/include/core/device_guard.h @@ -39,7 +39,7 @@ enum class MemcpyKind : int8_t { // DeviceGuard (the public RAII wrapper) forwards calls to the DeviceGuardImpl // instance registered for the device type. // -// TODO(zbl): add event managemnt +// TODO(dcj): add event management // class DeviceGuardImpl { public: @@ -111,8 +111,8 @@ class DeviceGuardImpl { // • Switches to the target device // • Restores the previous device on destruction // -// All runtime operations (memory, streams, BLAS, sync) are forwarded to the -// backend-specific DeviceGuardImpl registered for that device type. +// All runtime operations are forwarded to the backend-specific DeviceGuardImpl +// instance registered for that device type. // class DeviceGuard { public: @@ -120,45 +120,25 @@ class DeviceGuard { ~DeviceGuard(); + // Copy is disallowed DeviceGuard(const DeviceGuard &) = delete; DeviceGuard &operator=(const DeviceGuard &) = delete; - // Device operations - Device GetDevice() const; + // Move is disallowed, as DeviceGuard does not have an uninitialized state, + // which is required for moves on types with nontrivial destructors. 
+ DeviceGuard(DeviceGuard &&other) = delete; + DeviceGuard &operator=(DeviceGuard &&other) = delete; - void SetDevice(Device device) const; + void SetDevice(Device device); - int8_t DeviceCount() const; + Device current_device() const; - Device::DeviceType Type() const; - - // Stream operations - Stream *GetStream(Device) const; - - // Synchronization - void SynchronizeDevice(Device) const; - - void SynchronizeStream(Stream *) const; - - // BLAS - BlasHandle *GetBlasHandle(Device) const; - - // Memory operations - void Malloc(void **dev_ptr, size_t size); - - void MallocAsync(void **dev_ptr, size_t size, Stream *stream); - - void Free(void *dev_ptr); - - void FreeAsync(void *dev_ptr, Stream *stream); - - void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind); - - void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream); + Device original_device() const; private: DeviceGuardImpl *impl_ = nullptr; Device original_device_; + Device current_device_; }; // @@ -181,13 +161,15 @@ class DeviceGuardImplRegistry { DeviceGuardImpl *Get(Device::DeviceType type) const; private: - DeviceGuardImplRegistry() = default; + DeviceGuardImplRegistry(); DeviceGuardImplRegistry(const DeviceGuardImplRegistry &) = delete; DeviceGuardImplRegistry &operator=(const DeviceGuardImplRegistry &) = delete; std::unordered_map> impls_; }; +DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type); + } // namespace infini_train::core // diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 9ebe9e2a..fc95d64e 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -260,10 +260,8 @@ auto DispatchFunc(DataType dtype, Functor &&func, std::string_view context_ident CASE_FOR_TYPE(DataType::kINT64) CASE_FOR_TYPE(DataType::kFLOAT32) CASE_FOR_TYPE(DataType::kFLOAT64) -#ifdef USE_CUDA CASE_FOR_TYPE(DataType::kBFLOAT16) CASE_FOR_TYPE(DataType::kFLOAT16) -#endif #undef CASE_FOR_TYPE } LOG_UNSUPPORTED_DTYPE(dtype, context_identifier); @@ -328,10 +326,8 @@ template st CASE_FOR_TYPE(DataType::kINT64) CASE_FOR_TYPE(DataType::kFLOAT32) CASE_FOR_TYPE(DataType::kFLOAT64) -#ifdef USE_CUDA CASE_FOR_TYPE(DataType::kBFLOAT16) CASE_FOR_TYPE(DataType::kFLOAT16) -#endif #undef CASE_FOR_TYPE } LOG_UNSUPPORTED_DTYPE(dtype, context_identifier); diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 57e750ae..6482840e 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -92,7 +92,7 @@ class Module : public std::enable_shared_from_this { std::shared_ptr RegisterBackwardPostHook(ModulePostHook hook); protected: - Device device_ = Device(); + Device device_; const std::string type_ = kUndefinedType; std::unordered_map> modules_; std::unordered_map> parameters_; diff --git a/infini_train/include/nn/parallel/data_parallel.h b/infini_train/include/nn/parallel/data_parallel.h index 7d97f282..f794d501 100644 --- a/infini_train/include/nn/parallel/data_parallel.h +++ b/infini_train/include/nn/parallel/data_parallel.h @@ -13,7 +13,7 @@ class Tensor; namespace infini_train::nn::parallel { class DataParallel : public Module { public: - DataParallel(const std::shared_ptr &module, int dim = 0); + DataParallel(const std::shared_ptr &module, int dim, Device::DeviceType device_type); std::vector> Forward(const std::vector> &input_tensors) override; diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h 
b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h index b0f50b21..99d30703 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h +++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace infini_train::nn::parallel { namespace { diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index 6eef969d..25939bdc 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -30,7 +30,7 @@ struct StageInfo { class PipelineParallel : public Module { public: PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, - const std::vector> &recv_shape, int rank, int device_id, int vpp); + const std::vector> &recv_shape, int rank, Device device, int vpp); float TrainStep(const std::vector> &input, const std::vector> &target, const std::shared_ptr &optimizer, @@ -41,7 +41,7 @@ class PipelineParallel : public Module { std::vector> *mutable_chunks(); private: - void BuildPipelineStage(const std::vector> &recv_shape, int device_id, + void BuildPipelineStage(const std::vector> &recv_shape, Device device, std::vector> &&chunks); void SetupSchedule(int num_micro_batches); diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 79d84478..d1c2b9e7 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -134,8 +134,8 @@ class ProcessGroupNCCL final : public ProcessGroup { std::vector comms_; std::vector comm_streams_; - std::unordered_map device_comm_map_; - std::unordered_map device_stream_map_; + std::unordered_map device_comm_map_; // device_index : comm + std::unordered_map device_stream_map_; // device_index : comm_stream }; #endif diff --git a/infini_train/include/profiler.h b/infini_train/include/profiler.h index b0aa1fbe..6e0cf06f 100644 --- a/infini_train/include/profiler.h +++ b/infini_train/include/profiler.h @@ -84,17 +84,15 @@ class Profiler { std::vector call_records_; std::string current_tag_ = "Untagged"; + // thread-local tracking + thread_local static inline std::map cpu_timing_map_; + #ifdef USE_CUDA struct EventPair { void *start; void *stop; }; -#endif - // thread-local tracking - thread_local static inline std::map cpu_timing_map_; - -#ifdef USE_CUDA thread_local static inline std::map cuda_timing_map_; #endif }; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index b3a9b042..95665b43 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -78,6 +78,7 @@ class Tensor : public std::enable_shared_from_this { size_t NumElements() const; DataType Dtype() const; + // TODO(dcj): use scalar class later template void Fill(T value); Eigen::Map> EigenMatrix(); diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index 9e1ac184..165147ef 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -3,6 +3,7 @@ #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -22,7 +23,7 @@ AccumulateGrad::Backward(const std::vector> &grad_output auto grad = tensor_->grad(); auto device = 
tensor_->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); if (grad_output) { if (grad) { diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index 0e0028d0..d524088a 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -77,7 +77,7 @@ std::vector> Broadcast::Forward(const std::vectorGetDevice(); for (const auto &tensor : input_tensors) { - CHECK(!tensor->GetDevice()->IsCPU()) << "Broadcast function not implemented for CPU tensors"; + CHECK(!tensor->GetDevice().IsCPU()) << "Broadcast function not implemented for CPU tensors"; CHECK(tensor->GetDevice().type() == input_device_.type()) << "Broadcast function not implemented for tensors on different device type"; } diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 345f4d13..8e5fc508 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -6,9 +6,9 @@ #include "infini_train/include/autograd/function_hook.h" #include "infini_train/include/autograd/grad_mode.h" #include "infini_train/include/common/hook.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" -#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h" #include "infini_train/include/utils/precision_check_config.h" #include "infini_train/include/utils/precision_checker.h" @@ -18,8 +18,7 @@ namespace infini_train::autograd { std::vector> Function::Apply(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 1); auto device = input_tensors[0]->GetDevice(); - // TODO(dcj): Cache context information to reduce setDevice overhead. - device->SetDevice(); + core::DeviceGuard guard(device); // Register precision check hooks if enabled (before forward) if (!precision_check_registered_) { @@ -89,7 +88,7 @@ std::vector> Function::Apply(const std::vector &grad_output, int grad_output_idx) { auto device = grad_output->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); // NOTE(dcj): The accumulate autograd function has no grad_outputs. // Temporarily resize the vector to hold one nullptr as a buffer. 
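Note on the pattern introduced by this patch series: the hunks above and below replace the old per-call Device runtime API (device->SetDevice(), cuda_device->Stream()) with core::DeviceGuard and core::GetDeviceGuardImpl(). The sketch below is illustrative only and is not part of the patch; the helper name LaunchOnTensorDevice and the exact dynamic_cast target are assumptions, and it assumes a CUDA build of the tree shown in this series.

    #include <memory>

    #include "infini_train/include/core/device_guard.h"
    #include "infini_train/include/tensor.h"
    #include "infini_train/src/core/cuda/cuda_stream.h"

    namespace infini_train {

    // Hypothetical call site: shows the DeviceGuard RAII pattern plus the
    // per-device stream lookup the CUDA kernels in this patch rely on.
    void LaunchOnTensorDevice(const std::shared_ptr<Tensor> &tensor) {
        auto device = tensor->GetDevice();

        // Saves the current device, switches to `device`, and restores the
        // original device when the guard goes out of scope.
        core::DeviceGuard guard(device);

        // Backend-specific runtime calls go through the registered impl.
        auto *impl = core::GetDeviceGuardImpl(device.type());

        if (device.IsCUDA()) {
            // Assumed cast target; the kernels retrieve the cudaStream_t this
            // way before launching with <<<grid, block, 0, stream>>>.
            const auto *stream = dynamic_cast<const core::cuda::CudaStream *>(impl->GetStream(device));
            cudaStream_t cuda_stream = stream->cuda_stream();
            (void)cuda_stream; // launch kernels on this stream
        }
    }

    } // namespace infini_train

In short, the RAII guard replaces the scattered SetDevice() calls, and kernels obtain their launch stream from the registered DeviceGuardImpl instead of from a CudaDevice object.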
diff --git a/infini_train/src/core/cpu/cpu_guard.cc b/infini_train/src/core/cpu/cpu_guard.cc index 6d98d30f..263081b8 100644 --- a/infini_train/src/core/cpu/cpu_guard.cc +++ b/infini_train/src/core/cpu/cpu_guard.cc @@ -5,6 +5,8 @@ namespace infini_train::core::cpu { +CpuGuardImpl::CpuGuardImpl() {} + Device CpuGuardImpl::GetDevice() const { return Device(Device::DeviceType::kCPU, 0); } Device::DeviceType CpuGuardImpl::Type() const { return Device::DeviceType::kCPU; } diff --git a/infini_train/src/core/cuda/cuda_blas_handle.cc b/infini_train/src/core/cuda/cuda_blas_handle.cc index 38fe04cb..36da1eab 100644 --- a/infini_train/src/core/cuda/cuda_blas_handle.cc +++ b/infini_train/src/core/cuda/cuda_blas_handle.cc @@ -12,4 +12,6 @@ CudaBlasHandle::CudaBlasHandle(Stream *stream) { CUBLAS_CHECK(cublasSetStream(cublas_handle_, dynamic_cast(stream)->cuda_stream())); } +cublasHandle_t CudaBlasHandle::cublas_handle() const { return cublas_handle_; } + } // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_blas_handle.h b/infini_train/src/core/cuda/cuda_blas_handle.h index 86e1a53a..53678916 100644 --- a/infini_train/src/core/cuda/cuda_blas_handle.h +++ b/infini_train/src/core/cuda/cuda_blas_handle.h @@ -14,6 +14,8 @@ class CudaBlasHandle : public BlasHandle { public: explicit CudaBlasHandle(Stream *stream); + cublasHandle_t cublas_handle() const; + private: cublasHandle_t cublas_handle_; }; diff --git a/infini_train/src/core/cuda/cuda_guard.cc b/infini_train/src/core/cuda/cuda_guard.cc index 4ff5ecf0..be364ff2 100644 --- a/infini_train/src/core/cuda/cuda_guard.cc +++ b/infini_train/src/core/cuda/cuda_guard.cc @@ -39,7 +39,7 @@ void CudaGuardImpl::InitSingleHandle(Device device) { CUDA_CHECK(cudaGetDevice(¤t_device)); CUDA_CHECK(cudaSetDevice(device.index())); - std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device.index()); + std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device); cuda_blas_handles[device.index()] = std::make_unique(cuda_streams[device.index()].get()); @@ -85,7 +85,7 @@ void CudaGuardImpl::SynchronizeDevice(Device device) const { // blas BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const { - std::call_once(device_handle_flags.at(device.index()), InitSingleStream, device); + std::call_once(device_handle_flags.at(device.index()), InitSingleHandle, device); return cuda_blas_handles.at(device.index()).get(); } diff --git a/infini_train/src/core/cuda/cuda_stream.cc b/infini_train/src/core/cuda/cuda_stream.cc index 89319d08..82d04566 100644 --- a/infini_train/src/core/cuda/cuda_stream.cc +++ b/infini_train/src/core/cuda/cuda_stream.cc @@ -7,6 +7,6 @@ namespace infini_train::core::cuda { CudaStream::CudaStream() { CUDA_CHECK(cudaStreamCreate(&stream_)); } -cudaStream_t CudaStream::cuda_stream() { return stream_; } +cudaStream_t CudaStream::cuda_stream() const { return stream_; } } // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_stream.h b/infini_train/src/core/cuda/cuda_stream.h index 40dc7235..c5252097 100644 --- a/infini_train/src/core/cuda/cuda_stream.h +++ b/infini_train/src/core/cuda/cuda_stream.h @@ -10,7 +10,7 @@ class CudaStream : public Stream { public: CudaStream(); - cudaStream_t cuda_stream(); + cudaStream_t cuda_stream() const; private: cudaStream_t stream_; diff --git a/infini_train/src/core/device_guard.cc b/infini_train/src/core/device_guard.cc index 048d97e0..714ab6d5 100644 --- a/infini_train/src/core/device_guard.cc +++ 
b/infini_train/src/core/device_guard.cc @@ -11,11 +11,6 @@ #include "infini_train/src/core/cpu/cpu_guard.h" namespace infini_train::core { -namespace { -inline DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type) { - return DeviceGuardImplRegistry::Instance().Get(type); -} -} // namespace // DeviceGuardImpl void DeviceGuardImpl::SetDevice(Device device) const { @@ -78,46 +73,27 @@ DeviceGuard::DeviceGuard(Device device) : impl_(GetDeviceGuardImpl(device.type() impl_->SetDevice(device); } -DeviceGuard::~DeviceGuard() { impl_->SetDevice(original_device_); } - -Device DeviceGuard::GetDevice() const { return impl_->GetDevice(); } - -void DeviceGuard::SetDevice(Device device) const { return impl_->SetDevice(device); } - -int8_t DeviceGuard::DeviceCount() const { return impl_->DeviceCount(); } - -Device::DeviceType DeviceGuard::Type() const { return impl_->Type(); } - -Stream *DeviceGuard::GetStream(Device device) const { return impl_->GetStream(device); } - -void DeviceGuard::SynchronizeDevice(Device device) const { return impl_->SynchronizeDevice(device); } - -void DeviceGuard::SynchronizeStream(Stream *stream) const { return impl_->SynchronizeStream(stream); } - -BlasHandle *DeviceGuard::GetBlasHandle(Device device) const { return impl_->GetBlasHandle(device); } - -void DeviceGuard::Malloc(void **dev_ptr, size_t size) { impl_->Malloc(dev_ptr, size); } - -void DeviceGuard::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { - impl_->MallocAsync(dev_ptr, size, stream); +void DeviceGuard::SetDevice(Device device) { + if (current_device_ == device) { + return; + } + impl_->SetDevice(device); + current_device_ = device; } -void DeviceGuard::Free(void *dev_ptr) { impl_->Free(dev_ptr); } +Device DeviceGuard::current_device() const { return current_device_; } -void DeviceGuard::FreeAsync(void *dev_ptr, Stream *stream) { impl_->FreeAsync(dev_ptr, stream); } +Device DeviceGuard::original_device() const { return original_device_; } -void DeviceGuard::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { - impl_->Memcpy(dst, src, count, kind); -} +DeviceGuard::~DeviceGuard() { impl_->SetDevice(original_device_); } -void DeviceGuard::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { - impl_->MemcpyAsync(dst, src, count, kind, stream); +// DeviceGuardImplRegistry +DeviceGuardImplRegistry::DeviceGuardImplRegistry() { + Register(Device::DeviceType::kCPU, std::make_unique()); } -// DeviceGuardImplRegistry DeviceGuardImplRegistry &DeviceGuardImplRegistry::Instance() { static DeviceGuardImplRegistry instance; - instance.Register(Device::DeviceType::kCPU, std::make_unique()); return instance; } @@ -151,6 +127,6 @@ DeviceGuardImpl *DeviceGuardImplRegistry::Get(Device::DeviceType type) const { return it->second.get(); } -DeviceGuardImpl *GetDeviceGuard(Device::DeviceType type) { return DeviceGuardImplRegistry::Instance().Get(type); } +DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type) { return DeviceGuardImplRegistry::Instance().Get(type); } } // namespace infini_train::core diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index 5ff36a2d..d18f6320 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -5,6 +5,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace 
infini_train::kernels::cuda { @@ -22,15 +23,15 @@ void AccumulateGrad(const std::shared_ptr &gradient, float rate, const s int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - const auto *device = tensor->GetDevice(); - - auto device_impl = GetDeviceGuardImpl(device.type()); + auto device = tensor->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( gradient->Dtype(), [=]() { - AccumulateGradKernel<<(device_impl->GetStream(device))->cuda_stream()>>>( + AccumulateGradKernel<<>>( static_cast(gradient->DataPtr()), rate, static_cast(tensor->DataPtr()), num_elements); }, "CUDA AccumulateGrad"); @@ -65,12 +66,16 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad->GetDevice()); + + auto device = grad->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( grad->Dtype(), [=]() { - AdamAccumulateGradKernel<<Stream()>>>( + AdamAccumulateGradKernel<<>>( static_cast(grad->DataPtr()), static_cast(param->DataPtr()), num_elements, static_cast(m->DataPtr()), static_cast(v->DataPtr()), learning_rate, beta1, beta2, eps, bias_correction_m, bias_correction_v); diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 0feb6dae..e4698582 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -2,10 +2,12 @@ #include "infini_train/include/common/common.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -20,7 +22,10 @@ __global__ void CastKernel(Tdst *dst, const Tsrc *src, size_t num_elements, size std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { auto dst_tensor = std::make_shared(input->Dims(), dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); const size_t num_elements = input->NumElements(); dim3 block_dims(256); @@ -33,7 +38,7 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { auto dst = static_cast(dst_tensor->DataPtr()); auto src = static_cast(input->DataPtr()); for (size_t offset = 0; offset < num_elements; offset += step) { - CastKernel<<Stream()>>>(dst, src, num_elements, offset); + CastKernel<<>>(dst, src, num_elements, offset); } }, "CUDA Cast"); diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index 6beba608..b0f239f0 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -7,8 +7,10 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include 
"infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { __device__ __forceinline__ int64_t UpperBoundI64(const int64_t *offsets, int64_t n_plus_1, int64_t x) { @@ -91,8 +93,9 @@ std::shared_ptr ConcatForward(const std::vector> std::vector host_offsets(num_inputs + 1, 0); for (int64_t i = 0; i < num_inputs; ++i) { host_offsets[i + 1] = host_offsets[i] + Ks[i]; } - const auto *cuda_device = dynamic_cast(output->GetDevice()); - const auto &stream = cuda_device->Stream(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * K_total * D; int threads_per_block = 256; @@ -175,10 +178,12 @@ std::vector> ConcatBackward(const std::shared_ptrGetDevice(); + std::vector> grads; grads.reserve(input_dims_list.size()); for (const auto &dvec : input_dims_list) { - auto t = std::make_shared(dvec, dtype, grad_output->GetDevice()); + auto t = std::make_shared(dvec, dtype, device); DispatchFunc( dtype, [=]() { t->Fill(0); }, "CUDA ConcatBackward"); grads.push_back(t); @@ -194,8 +199,9 @@ std::vector> ConcatBackward(const std::shared_ptr host_offsets(num_inputs + 1, 0); for (int64_t i = 0; i < num_inputs; ++i) { host_offsets[i + 1] = host_offsets[i] + Ks[i]; } - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); - const auto &stream = cuda_device->Stream(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * K_total * D; int threads_per_block = 256; diff --git a/infini_train/src/kernels/cuda/cross_entropy.cu b/infini_train/src/kernels/cuda/cross_entropy.cu index a9c7fda7..333f8e2b 100644 --- a/infini_train/src/kernels/cuda/cross_entropy.cu +++ b/infini_train/src/kernels/cuda/cross_entropy.cu @@ -8,8 +8,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/cub_compat.cuh" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { namespace { @@ -83,7 +85,11 @@ std::shared_ptr CrossEntropyForward(const std::shared_ptr &input constexpr int threads_per_block = 256; int num_blocks = bs; - const auto *cuda_device = dynamic_cast(target->GetDevice()); + auto device = target->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + return DispatchFunc, DataTypeList>( {target->Dtype(), input->Dtype()}, [=]() { @@ -92,8 +98,8 @@ std::shared_ptr CrossEntropyForward(const std::shared_ptr &input Tinput *batched_loss_ptr = static_cast(batched_output->DataPtr()); // FIXME(dcj): do reduce on GPU CrossEntropyForwardKernel - <<Stream()>>>(input_ptr, target_ptr, batched_loss_ptr, - bs, num_classes); + <<>>(input_ptr, target_ptr, batched_loss_ptr, bs, + num_classes); auto loss_cpu = batched_output->To(Device()); auto loss = std::make_shared(std::vector{}, input->Dtype(), Device()); @@ -186,7 +192,11 @@ std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu constexpr int threads_per_block = 256; int num_blocks = bs; - const auto *cuda_device = dynamic_cast(target->GetDevice()); + auto device = target->GetDevice(); + const auto &cuda_stream = dynamic_cast( + 
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc, DataTypeList>( {target->Dtype(), input_casted->Dtype()}, [=]() { @@ -196,8 +206,8 @@ std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu const Tinput *input_ptr = static_cast(input_casted->DataPtr()); Tinput *input_grad_ptr = static_cast(grad_input->DataPtr()); CrossEntropyBackwardKernel - <<Stream()>>>(input_ptr, input_grad_ptr, target_ptr, - output_grad_ptr, bs, num_classes); + <<>>(input_ptr, input_grad_ptr, target_ptr, + output_grad_ptr, bs, num_classes); }, "CUDA CrossEntropyBackward"); diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 8bbc72fc..2ebba200 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -4,8 +4,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { namespace { @@ -69,16 +71,18 @@ void LaunchKernel(Kernel &&kernel, const std::shared_ptr &output, const // Note: currently only support unary and binary operations template void LaunchForward(Func func, const std::shared_ptr &output, const Inputs &...inputs) { - const auto *cuda_device = dynamic_cast(output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); T *output_ptr = static_cast(output->DataPtr()); if constexpr (sizeof...(inputs) == 1) { // Unary case LaunchKernel( [&](dim3 grid, dim3 block, size_t offset, auto... 
ptrs) { - UnaryForwardKernel<<>>(output_ptr, func, output->NumElements(), offset, - ptrs...); + UnaryForwardKernel<<>>(output_ptr, func, output->NumElements(), offset, + ptrs...); }, output, inputs...); } else if constexpr (sizeof...(inputs) == 2) { @@ -102,7 +106,7 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input auto out_stride_host = ComputeStrides(out_shape); int64_t *device_buffer; - cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), stream); + cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), cuda_stream); int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape; device_a_strides = device_buffer + ndim * 0; @@ -118,17 +122,18 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); - cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, + cuda_stream); LaunchKernel( [&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) { - BinaryForwardKernel<<>>( + BinaryForwardKernel<<>>( output_ptr, func, ndim, device_a_strides, device_a_shape, device_b_strides, device_b_shape, device_out_strides, a_ptr, b_ptr, output->NumElements()); }, output, inputs...); - cudaFreeAsync(device_buffer, stream); + cudaFreeAsync(device_buffer, cuda_stream); } else { static_assert(sizeof...(inputs) == 1 || sizeof...(inputs) == 2, "LaunchForward currently only supports unary and binary operations."); @@ -488,14 +493,18 @@ __global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB template void LaunchBackward(Func func, const std::shared_ptr &output, const std::shared_ptr &grad_output, const Inputs &...inputs) { - const auto *cuda_device = dynamic_cast(output->GetDevice()); + auto device = output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + T *output_ptr = static_cast(output->DataPtr()); const T *grad_ptr = static_cast(grad_output->DataPtr()); LaunchKernel( [=](dim3 grid, dim3 block, size_t offset, auto... 
ptrs) { - UnaryBackwardKernel<<Stream()>>>(output_ptr, func, output->NumElements(), - offset, grad_ptr, ptrs...); + UnaryBackwardKernel<<>>(output_ptr, func, output->NumElements(), offset, + grad_ptr, ptrs...); }, output, inputs...); } @@ -506,8 +515,11 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out const std::shared_ptr &output_b, const std::vector &a_dims, const std::vector &b_dims, const std::shared_ptr &grad_output, const Inputs &...inputs) { - const auto *cuda_device = dynamic_cast(output_a->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = output_a->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + T *output_a_ptr = static_cast(output_a->DataPtr()); T *output_b_ptr = static_cast(output_b->DataPtr()); const T *grad_output_ptr = static_cast(grad_output->DataPtr()); diff --git a/infini_train/src/kernels/cuda/embedding.cu b/infini_train/src/kernels/cuda/embedding.cu index f43239b2..ec025098 100644 --- a/infini_train/src/kernels/cuda/embedding.cu +++ b/infini_train/src/kernels/cuda/embedding.cu @@ -1,8 +1,10 @@ #include #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -30,7 +32,11 @@ std::shared_ptr EmbeddingForward(const std::shared_ptr &input, c CHECK(input->Dtype() == DataType::kINT64); CHECK_EQ(weight->Dims().size(), 2); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int batch_size = input->Dims().size() == 2 ? input->Dims()[0] : 1; const int max_seqlen = input->Dims().size() == 2 ? 
input->Dims()[1] : input->Dims()[0]; const int vocab_size = weight->Dims()[0]; @@ -46,7 +52,7 @@ std::shared_ptr EmbeddingForward(const std::shared_ptr &input, c DispatchFunc( dtype, [=]() { - EmbeddingForwardKernel<<Stream()>>>( + EmbeddingForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), static_cast(weight->DataPtr()), batch_size, max_seqlen, embed_dim, vocab_size); }, @@ -77,7 +83,11 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, const std::shared_ptr &grad_output) { CHECK(input->Dtype() == DataType::kINT64); CHECK_EQ(weight_dims.size(), 2); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int vocab_size = weight_dims[0]; const int embedding_dim = weight_dims[1]; CHECK_EQ(input->Dims().size() + 1, grad_output->Dims().size()); @@ -94,7 +104,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, dtype, [=]() { grad_weight->Fill(0); - EmbeddingBackwardKernel<<Stream()>>>( + EmbeddingBackwardKernel<<>>( static_cast(input->DataPtr()), static_cast(grad_output->DataPtr()), static_cast(grad_weight->DataPtr()), num_tokens, embedding_dim, vocab_size); }, diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index 4a5d2f45..a278a93a 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -1,9 +1,11 @@ #include #include +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -19,12 +21,15 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { const int num_tokens = tensor->NumElements(); const int threads_per_block = 256; const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(tensor->GetDevice()); + auto device = tensor->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( tensor->Dtype(), [=]() { - FillKernel<<Stream()>>>( + FillKernel<<>>( static_cast(tensor->DataPtr()), *(static_cast(value_ptr)), tensor->NumElements()); }, "CUDA Fill"); diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index a111721b..382e3d5a 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -1,8 +1,10 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { // FIXME(zbl): This kernel aligns with torch.gather @@ -66,11 +68,13 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, << "index.size(" << d << ") must be <= input.size(" << d << ") on non-gather dims"; } - const auto *cuda_dev = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_dev->Stream(); + const auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); auto 
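The launches in this patch text read `Kernel<<>>(...)` because the contents of the `<<< ... >>>` chevrons were stripped along with the other angle-bracket contents; judging from the grid/block variables computed right before each launch, the intended form is `<<<num_blocks, threads_per_block, 0, cuda_stream>>>`. A small self-contained CUDA illustration of that launch shape (hypothetical kernel, not from the patch):

    #include <cstddef>
    #include <cuda_runtime.h>

    // Fill-style kernel launched with an explicit stream as the fourth launch
    // parameter, mirroring the form the patched kernels use.
    __global__ void FillExampleKernel(float *data, float value, size_t n) {
        size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
        if (i < n) {
            data[i] = value;
        }
    }

    void LaunchFillExample(float *data, float value, size_t n, cudaStream_t cuda_stream) {
        const int threads_per_block = 256;
        const int num_blocks = static_cast<int>((n + threads_per_block - 1) / threads_per_block);
        FillExampleKernel<<<num_blocks, threads_per_block, 0, cuda_stream>>>(data, value, n);
    }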
dtype = input->Dtype(); - auto out = std::make_shared(idx_dims, dtype, cuda_dev); + auto out = std::make_shared(idx_dims, dtype, device); auto in_strides = ComputeStrides(in_dims); auto out_strides = ComputeStrides(idx_dims); @@ -183,8 +187,11 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ const size_t n_out_strides = idx_dims.size(); const size_t total_i64 = n_out + n_in_strides + n_out_strides; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMallocAsync(&dev_buf, total_i64 * sizeof(int64_t), stream)); int64_t *out_dims_dev = dev_buf; int64_t *in_strides_dev = out_dims_dev + n_out; diff --git a/infini_train/src/kernels/cuda/layernorm.cu b/infini_train/src/kernels/cuda/layernorm.cu index 70d9932f..47e5654e 100644 --- a/infini_train/src/kernels/cuda/layernorm.cu +++ b/infini_train/src/kernels/cuda/layernorm.cu @@ -2,9 +2,11 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -77,13 +79,17 @@ LayerNormForward(const std::shared_ptr &input, const std::shared_ptr(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc( dtype, [=]() { mean->Fill(0); rstd->Fill(0); - LayerNormForwardKernel<<Stream()>>>( + LayerNormForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(weight->DataPtr()), static_cast(bias->DataPtr()), static_cast(mean->DataPtr()), static_cast(rstd->DataPtr()), static_cast(output->DataPtr()), eps, embed_dim); @@ -168,14 +174,17 @@ LayerNormBackward(const std::shared_ptr &input, const std::shared_ptr(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { grad_input->Fill(0); grad_weight->Fill(0); grad_bias->Fill(0); - LayerNormBackwardKernel<<Stream()>>>( + LayerNormBackwardKernel<<>>( static_cast(input->DataPtr()), static_cast(grad_output->DataPtr()), static_cast(mean->DataPtr()), static_cast(rstd->DataPtr()), static_cast(weight->DataPtr()), static_cast(grad_input->DataPtr()), diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index e5b90d3c..dbbc8cce 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -8,8 +8,11 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_blas_handle.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -39,9 +42,11 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons output_dims[output_dims.size() - 1] = n; auto output = 
std::make_shared(output_dims, dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); // cuBLAS is colmun-major // output = input * other --> output.T = other.T * input.T @@ -129,9 +134,11 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptr(input_promoted->GetDevice()); + auto device = input_promoted->GetDevice(); const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); { // cuBLAS is colmun-major @@ -230,7 +237,10 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons *output_dims.rbegin() = out_features; auto output = std::make_shared(output_dims, dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); if (bias) { CHECK_EQ(bias->Dims().size(), 1); @@ -241,7 +251,7 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons DispatchFunc( dtype, [=]() { - BiasCopyKernel<<Stream()>>>( + BiasCopyKernel<<>>( static_cast(output->DataPtr()), static_cast(bias->DataPtr()), bs, out_features); }, "CUDA LinearForward"); @@ -255,7 +265,9 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons auto trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; auto trans_b = CUBLAS_OP_N; auto lda = transpose ? in_features : out_features; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); // TODO(zbl): use cublasSgemv if possible for convenience and simplicity // @@ -353,7 +365,11 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr( promoted_type, [=]() { initialize_gradients(T(0), promoted_type); }, "CUDA LinearBackward"); - const auto *cuda_device = dynamic_cast(input_promoted->GetDevice()); + auto device = input_promoted->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + float alpha = 1.0f; float beta = 0.0f; auto trans_a1 = transpose ? 
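The cuBLAS paths (MatmulForward/Backward, LinearForward/Backward, and OuterForward/Backward further down) all switch from CudaDevice::CublasHandle() to a lookup through GetBlasHandle(device) on the registered impl. The downcast target was stripped from the patch text; assuming cuda_blas_handle.h defines a CudaBlasHandle wrapper with a cublas_handle() accessor, the lookup can be sketched as:

    // Hypothetical helper, not in the patch; CudaBlasHandle (name and namespace
    // assumed) wraps the cublasHandle_t owned per device.
    #include <cublas_v2.h>

    #include "infini_train/include/core/device_guard.h"
    #include "infini_train/include/device.h"
    #include "infini_train/src/core/cuda/cuda_blas_handle.h"

    namespace infini_train::kernels::cuda {

    inline cublasHandle_t LookUpCublasHandle(Device device) {
        auto impl = core::GetDeviceGuardImpl(device.type());
        auto *blas = dynamic_cast<core::CudaBlasHandle *>(impl->GetBlasHandle(device));
        return blas->cublas_handle();
    }

    } // namespace infini_train::kernels::cuda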
CUBLAS_OP_N : CUBLAS_OP_T; @@ -369,55 +385,57 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptrCublasHandle(); + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); switch (promoted_type) { // TODO(zbl): use cublasSgemv if possible - DISPATCH_CASE( - WRAP({ - // - if transpose: - // weight is [out_features, in_features] here - // d_input = d_output * weight --> d_input.T = weight.T * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[in_features, out_features] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // weight is [in_features, out_features] here - // d_input = d_output * weight.T --> d_input.T = weight * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[out_features, in_features] - // B = d_output.T[out_features, bs] - CUBLAS_CHECK(cublasSgemm(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, - static_cast(weight_promoted->DataPtr()), lda1, - static_cast(grad_output_promoted->DataPtr()), out_features, - &beta, static_cast(grad_input->DataPtr()), in_features)); - // - if transpose: - // d_weight = d_output.T * input --> d_weight.T = input.T * d_output - // C = d_weight.T[in_features, out_features] - // A = input.T[in_features, bs] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // d_weight = input.T * d_output --> d_weight.T = d_output.T * input - // C = d_weight.T[out_features, in_features] - // A = d_output.T[out_features, bs] - // B = input.T[in_features, bs] - CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, static_cast(a2), - lda2, static_cast(b2), ldb2, &beta, - static_cast(grad_weight->DataPtr()), ldc2)); - // d_bias = \sum_i(i=0, bs-1) d_output[i] - // TODO(dcj): use thrust::fill or reduce kernel do this - if (bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<Stream()>>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); - } - }), - DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + // - if transpose: + // weight is [out_features, in_features] here + // d_input = d_output * weight --> d_input.T = weight.T * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[in_features, out_features] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // weight is [in_features, out_features] here + // d_input = d_output * weight.T --> d_input.T = weight * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[out_features, in_features] + // B = d_output.T[out_features, bs] + CUBLAS_CHECK(cublasSgemm(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, + static_cast(weight_promoted->DataPtr()), lda1, + static_cast(grad_output_promoted->DataPtr()), + out_features, &beta, static_cast(grad_input->DataPtr()), + in_features)); + // - if transpose: + // d_weight = d_output.T * input --> d_weight.T = input.T * d_output + // C = d_weight.T[in_features, out_features] + // A = input.T[in_features, bs] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // d_weight = input.T * d_output --> d_weight.T = d_output.T * input + // C = d_weight.T[out_features, in_features] + // A = d_output.T[out_features, bs] + // B = input.T[in_features, bs] + CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, + static_cast(a2), lda2, static_cast(b2), + ldb2, &beta, 
static_cast(grad_weight->DataPtr()), ldc2)); + // d_bias = \sum_i(i=0, bs-1) d_output[i] + // TODO(dcj): use thrust::fill or reduce kernel do this + if (bias) { + constexpr int BLOCK_SIZE = 256; + int threads_per_block = BLOCK_SIZE; + int num_blocks = out_features; + ReduceColumnsKernel<<>>( + static_cast(grad_output_promoted->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); + } + }), + DataType::kFLOAT32) DISPATCH_CASE(WRAP({ CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, weight_promoted->DataPtr(), CUDA_R_16BF, lda1, @@ -431,10 +449,9 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr - <<Stream()>>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); + ReduceColumnsKernel<<>>( + static_cast(grad_output_promoted->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); } }), DataType::kBFLOAT16) diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index e9716072..d595024a 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -7,8 +7,10 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_blas_handle.h" namespace infini_train::kernels::cuda { std::shared_ptr OuterForward(const std::shared_ptr &input, const std::shared_ptr &other) { @@ -28,14 +30,16 @@ std::shared_ptr OuterForward(const std::shared_ptr &input, const auto output = std::make_shared(std::vector{M, N}, input->Dtype(), input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); // reinterpret input: [M] as column vector [M, 1] // reinterpret other: [N] as row vector [1, N] // output[M, N] = input[M, 1] * other.T[1, N] // output.T[N, M] = other[N, 1] * input.T[1, M] float alpha = 1.0f; float beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); switch (input->Dtype()) { DISPATCH_CASE(WRAP({ @@ -97,10 +101,12 @@ std::tuple, std::shared_ptr> OuterBackward(const }, "CUDA OuterBackward"); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); float alpha = 1.0f; float beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); switch (promoted_type) { DISPATCH_CASE(WRAP({ diff --git a/infini_train/src/kernels/cuda/reduction.cu b/infini_train/src/kernels/cuda/reduction.cu index 5d8f2c15..ac5e6d20 100644 --- a/infini_train/src/kernels/cuda/reduction.cu +++ b/infini_train/src/kernels/cuda/reduction.cu @@ -3,8 +3,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/cub_compat.cuh" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { namespace { @@ -133,14 +135,18 @@ std::shared_ptr ReduceOpForward(const 
std::shared_ptr &input, co int threads_per_block = BLOCK_SIZE; int num_blocks = N * W; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc( dtype, [=]() { GenericReduceKernel, BLOCK_SIZE> - <<Stream()>>>(static_cast(input->DataPtr()), - static_cast(output->DataPtr()), N, H, - W, FinalizeOp{}); + <<>>(static_cast(input->DataPtr()), + static_cast(output->DataPtr()), N, H, W, + FinalizeOp{}); }, "CUDA ReductionForward"); return output; @@ -165,12 +171,16 @@ std::shared_ptr ReduceOpBackward(const std::shared_ptr &grad_out int threads_per_block = 256; int num_blocks = (N * H * W + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc( dtype, [=]() { grad_input->Fill(0); - GenericReduceBackwardKernel<<Stream()>>>( + GenericReduceBackwardKernel<<>>( static_cast(grad_input->DataPtr()), static_cast(grad_output->DataPtr()), input ? static_cast(input->DataPtr()) : nullptr, reduced ? static_cast(reduced->DataPtr()) : nullptr, N, H, W, is_mean, is_masked); diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 032ebc47..933943af 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -4,8 +4,10 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -65,8 +67,11 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const int64_t *new_dims_dev, *starts_dev, *steps_dev, *input_strides_dev, *output_strides_dev; - const auto *cuda_device = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaMallocAsync(&new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), stream); @@ -157,8 +162,10 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output int dims_size = dims.size(); int64_t *new_dims_dev, *starts_dev, *steps_dev, *input_strides_dev, *output_strides_dev; - const auto *cuda_device = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaMallocAsync(&new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), stream); diff --git a/infini_train/src/kernels/cuda/softmax.cu b/infini_train/src/kernels/cuda/softmax.cu index 0184dc7a..7c453ab2 100644 --- a/infini_train/src/kernels/cuda/softmax.cu +++ b/infini_train/src/kernels/cuda/softmax.cu @@ -8,8 +8,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/cub_compat.cuh" #include 
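The gather, slice, and elementwise paths stage their flattened dims/strides through a temporary device buffer; after this change the cudaMallocAsync / cudaMemcpyAsync / cudaFreeAsync calls all target the stream looked up from the registry, so the upload, the kernel, and the free stay ordered on one stream. A minimal sketch of that staging step (hypothetical helper; the real code inlines it per kernel):

    #include <cstdint>
    #include <vector>
    #include <cuda_runtime.h>

    // Uploads flattened metadata asynchronously; the caller launches its kernel
    // on the same `stream` and then releases the buffer with
    // cudaFreeAsync(dev_meta, stream). host_meta is pageable, matching how the
    // kernels in this patch pass local std::vector buffers.
    int64_t *UploadMetadataAsync(const std::vector<int64_t> &host_meta, cudaStream_t stream) {
        int64_t *dev_meta = nullptr;
        cudaMallocAsync(&dev_meta, host_meta.size() * sizeof(int64_t), stream);
        cudaMemcpyAsync(dev_meta, host_meta.data(), host_meta.size() * sizeof(int64_t),
                        cudaMemcpyHostToDevice, stream);
        return dev_meta;
    }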
"infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { template @@ -87,9 +89,12 @@ void LaunchForward(const std::shared_ptr &output, const std::shared_ptr< dim3 block_dims(BLOCK_SIZE); dim3 grid_dims(outer_size, inner_size); - const auto *cuda_device = dynamic_cast(output->GetDevice()); + auto device = output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); SoftmaxForwardKernel - <<Stream()>>>(output_ptr, input_ptr, outer_size, axis_size, inner_size); + <<>>(output_ptr, input_ptr, outer_size, axis_size, inner_size); } std::shared_ptr SoftmaxForward(const std::shared_ptr &input, int64_t dim) { @@ -168,9 +173,12 @@ void LaunchBackward(const std::shared_ptr &grad_input, const std::shared dim3 block(BLOCK_SIZE); dim3 grid(outer_size, inner_size); - const auto *cuda_device = dynamic_cast(output->GetDevice()); - SoftmaxBackwardKernel<<Stream()>>>( - grad_input_ptr, grad_output_ptr, output_ptr, outer_size, axis_size, inner_size); + auto device = output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + SoftmaxBackwardKernel<<>>(grad_input_ptr, grad_output_ptr, output_ptr, + outer_size, axis_size, inner_size); } std::shared_ptr SoftmaxBackward(const std::shared_ptr &grad_output, diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index 5b2c4838..ee887b7e 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -3,8 +3,10 @@ #include #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { template @@ -50,12 +52,15 @@ std::vector> SplitForward(const std::shared_ptr int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { - SplitForwardKernel<<Stream()>>>( + SplitForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), N, H_in, H_out, W, start); }, @@ -114,8 +119,10 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di int64_t H_in = input_dims[dim]; int64_t num_splits = grad_outputs.size(); - const auto *cuda_device = dynamic_cast(grad->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = grad->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); // init the array of grad_output ptrs std::vector host_grad_output_ptrs; for (const auto &grad_output : grad_outputs) { diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index cef9a05f..d544f11d 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -7,8 +7,10 @@ #include 
"glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { template @@ -48,8 +50,10 @@ std::shared_ptr StackForward(const std::vector> const int64_t D = std::accumulate(base_dims.begin() + dim, base_dims.end(), 1, std::multiplies()); const int64_t num_inputs = inputs.size(); - const auto *cuda_device = dynamic_cast(output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * num_inputs * D; int threads_per_block = 256; @@ -115,8 +119,10 @@ std::vector> StackBackward(const std::vector &i int64_t N = std::accumulate(input_dims.begin(), input_dims.begin() + dim, 1, std::multiplies()); int64_t D = std::accumulate(input_dims.begin() + dim, input_dims.end(), 1, std::multiplies()); - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * num_inputs * D; int threads_per_block = 256; diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 9a7cee41..d9284c46 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -6,8 +6,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -38,12 +40,15 @@ std::shared_ptr TrilForward(const std::shared_ptr &input, int64_ int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( input->Dtype(), [=]() { - TrilForwardKernel<<Stream()>>>( + TrilForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, cols, diagonal); }, "CUDA TrilForward"); @@ -78,13 +83,16 @@ std::shared_ptr TrilBackward(const std::shared_ptr &grad_output, int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { grad_input->Fill(0); - TrilBackwardKernel<<Stream()>>>( + TrilBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(grad_input->DataPtr()), rows, cols, diagonal); }, @@ -120,12 +128,15 @@ std::shared_ptr TriuForward(const std::shared_ptr &input, int64_ int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - 
const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( input->Dtype(), [=]() { - TriuForwardKernel<<Stream()>>>( + TriuForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, cols, diagonal); }, "CUDA TriuForward"); @@ -159,13 +170,16 @@ std::shared_ptr TriuBackward(const std::shared_ptr &grad_output, int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { grad_input->Fill(0); - TriuBackwardKernel<<Stream()>>>( + TriuBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(grad_input->DataPtr()), rows, cols, diagonal); }, @@ -229,8 +243,10 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i out_strides[i] = out_strides[i + 1] * out_dims[i + 1]; } - const auto *cuda_device = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); // Allocate device memory for dims and strides // TODO(zbl): avoid using cudaMalloc? @@ -341,7 +357,11 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const MaskMode mode = DecideMaskMode(input_shape, mask_shape); auto output = std::make_shared(input_shape, dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(output->GetDevice()); + auto device = output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + int threads_per_block = 256; if (mode == MaskMode::kLead) { @@ -352,7 +372,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const DispatchFunc( dtype, [=]() { - MaskLeadsForwardKernel<<Stream()>>>( + MaskLeadsForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(mask->DataPtr()), static_cast(output->DataPtr()), common::cuda::Cast(value), rows, inner); }, @@ -365,7 +385,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const DispatchFunc( dtype, [=]() { - MaskForwardKernel<<Stream()>>>( + MaskForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(mask_casted->DataPtr()), static_cast(output->DataPtr()), common::cuda::Cast(value), static_cast(batch_size), static_cast(mask_size)); @@ -401,7 +421,11 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, MaskMode mode = DecideMaskMode(output_shape, mask_shape); auto grad_input = std::make_shared(output_shape, dtype, grad_output->GetDevice()); - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + int threads_per_block = 256; if (mode == MaskMode::kLead) { @@ -413,7 +437,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, dtype, [=]() { grad_input->Fill(0); - MaskLeadsBackwardKernel<<Stream()>>>( + MaskLeadsBackwardKernel<<>>( static_cast(grad_output->DataPtr()), 
static_cast(mask_casted->DataPtr()), static_cast(grad_input->DataPtr()), rows, inner); }, @@ -427,7 +451,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, dtype, [=]() { grad_input->Fill(0); - MaskBackwardKernel<<Stream()>>>( + MaskBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(mask_casted->DataPtr()), static_cast(grad_input->DataPtr()), static_cast(batch_size), static_cast(mask_size)); }, @@ -473,12 +497,15 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i int64_t total_elements = outer * dim_size * repeat * inner; int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( input->Dtype(), [=]() { - RepeatInterleaveForwardKernel<<Stream()>>>( + RepeatInterleaveForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), outer, dim_size, inner, repeat); }, @@ -528,13 +555,16 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & int64_t total_elements = outer * dim_size * inner; int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( grad_output->Dtype(), [=]() { grad_input->Fill(0); - RepeatInterleaveBackwardKernel<<Stream()>>>( + RepeatInterleaveBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(grad_input->DataPtr()), outer, dim_size, inner, repeat); }, diff --git a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu index 75f0206c..e023cbbd 100644 --- a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu +++ b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu @@ -4,8 +4,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -74,10 +76,13 @@ VocabParallelCrossEntropyBackward(const std::shared_ptr &grad_output, dloss_is_scalar = (grad_output->NumElements() == 1); } - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); // logits should be [rows, V_local] - auto grad_input = std::make_shared(softmax_local->Dims(), softmax_local->Dtype(), cuda_device); + auto grad_input = std::make_shared(softmax_local->Dims(), softmax_local->Dtype(), device); const float one_minus_label_smoothing = 1.0f - label_smoothing; const float smoothing_term = (label_smoothing > 0.f && vocab_size_original > 0) @@ -100,10 +105,10 @@ VocabParallelCrossEntropyBackward(const std::shared_ptr &grad_output, Tinput *grad_input_ptr = static_cast(grad_input->DataPtr()); VocabParallelCrossEntropyBackwardKernel - <<Stream()>>>( - 
softmax_ptr, grad_input_ptr, mtarget_ptr, tmask_ptr, vml_ptr, grad_output_ptr, - static_cast(rows), static_cast(vocab_size_local), dloss_is_scalar, - one_minus_label_smoothing, smoothing_term); + <<>>(softmax_ptr, grad_input_ptr, mtarget_ptr, tmask_ptr, + vml_ptr, grad_output_ptr, static_cast(rows), + static_cast(vocab_size_local), dloss_is_scalar, + one_minus_label_smoothing, smoothing_term); }, "CUDA VocabParallelCrossEntropyBackward"); diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 4b704d5d..9cd56f7a 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -16,8 +16,10 @@ #include "glog/logging.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::init { namespace { @@ -46,26 +48,12 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean #endif auto device = tensor->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); + auto impl = core::GetDeviceGuardImpl(device.type()); - switch (device.type()) { - case Device::DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); - break; - } - } + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); return tensor; } @@ -152,26 +140,14 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, #endif auto device = tensor->GetDevice(); - device->SetDevice(); - switch (device.type()) { - case Device::DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); - break; - } - } + core::DeviceGuard guard(device); + auto impl = core::GetDeviceGuardImpl(device.type()); + + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + return tensor; } @@ -182,26 +158,14 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { std::vector buffer(num_elements, 1.0f); auto device = tensor->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); + + auto impl = core::GetDeviceGuardImpl(device.type()); + + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); - switch (device.type()) { - case Device::DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); - break; - } - } return tensor; } @@ -212,26 +176,14 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { std::vector buffer(num_elements, 0.0f); auto device = tensor->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); + + auto impl = core::GetDeviceGuardImpl(device.type()); + + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); - switch (device.type()) { - case Device::DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice().type()); - break; - } - } return tensor; } @@ -247,16 +199,18 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { std::vector buffer(num_elements); \ std::iota(buffer.begin(), buffer.end(), static_cast(start)); \ cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), cudaMemcpyHostToDevice, \ - dynamic_cast(device)->Stream()); \ + dynamic_cast( \ + core::GetDeviceGuardImpl(device.type())->GetStream(device)) \ + ->cuda_stream()); \ break; \ } std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Device device) { int64_t num_elements = end - start; auto tensor = std::make_shared(std::vector{num_elements}, dtype, device); - device->SetDevice(); + core::DeviceGuard guard(device); - if (device->IsCPU()) { + if (device.IsCPU()) { switch (dtype) { CASE(DataType::kUINT8, uint8_t) CASE(DataType::kINT8, int8_t) diff --git a/infini_train/src/nn/modules/linear.cc b/infini_train/src/nn/modules/linear.cc index e5a58d01..7b93fa94 100644 --- a/infini_train/src/nn/modules/linear.cc +++ b/infini_train/src/nn/modules/linear.cc @@ -12,7 +12,7 @@ namespace infini_train::nn { Linear::Linear(int64_t in_features, int64_t out_features, bool bias, Device device) : CloneableModule(kType), bias_(bias) { - device_ = device ? device : Device(); + device_ = device; parameters_[kParamWeightName] = std::make_shared(std::vector{out_features, in_features}, DataType::kFLOAT32, device_) diff --git a/infini_train/src/nn/modules/normalization.cc b/infini_train/src/nn/modules/normalization.cc index df1d4afb..73ca7b8f 100644 --- a/infini_train/src/nn/modules/normalization.cc +++ b/infini_train/src/nn/modules/normalization.cc @@ -11,7 +11,7 @@ namespace infini_train::nn { LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, Device device) : CloneableModule(kType), eps_(eps) { - device_ = device ? 
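In nn/init.cc the per-device switch statements (memcpy for CPU, cudaMemcpyAsync for CUDA) collapse into a single path: a core::DeviceGuard scopes the device and the registered impl's MemcpyAsync performs the copy, with kH2D for CUDA and kD2D for CPU (the CPU impl presumably treats every kind as a plain host memcpy). A minimal sketch of the shared shape, using the same accessors the hunks use:

    #include <memory>
    #include <vector>

    #include "infini_train/include/core/device_guard.h"
    #include "infini_train/include/device.h"
    #include "infini_train/include/tensor.h"

    namespace infini_train::nn::init {

    // Sketch of the unified copy used by Normal/Uniform/Ones/Zeros: `buffer` is
    // the host-side staging vector the initializers already build.
    void CopyHostBufferToTensor(const std::vector<float> &buffer, const std::shared_ptr<Tensor> &tensor) {
        auto device = tensor->GetDevice();
        core::DeviceGuard guard(device);
        auto impl = core::GetDeviceGuardImpl(device.type());
        impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), buffer.size() * sizeof(float),
                          device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D
                                                                    : core::MemcpyKind::kH2D,
                          impl->GetStream(device));
    }

    } // namespace infini_train::nn::init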
device : Device(); + device_ = device; parameters_[kParamWeightName] = std::make_shared(normalized_shape, DataType::kFLOAT32, device_)->RequiresGrad(); diff --git a/infini_train/src/nn/parallel/data_parallel.cc b/infini_train/src/nn/parallel/data_parallel.cc index d7899e44..2b7c3425 100644 --- a/infini_train/src/nn/parallel/data_parallel.cc +++ b/infini_train/src/nn/parallel/data_parallel.cc @@ -8,6 +8,7 @@ #include "glog/logging.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/global.h" @@ -30,7 +31,7 @@ ParallelApply(const std::vector> &modules, auto worker = [&](const std::shared_ptr &module, const std::vector> &inputs, Device device, int idx) { - device->SetDevice(); + core::DeviceGuard guard(device); auto output = (*module)(inputs); results[idx] = output; }; @@ -81,7 +82,7 @@ std::vector> DataParallel::Forward(const std::vectorGetDevice() != src_device_) { LOG(FATAL) << std::format("module must have its Parameters on device {} (device_ids[0]) but found " "one of them on device: {}", - src_device_->ToString(), tensor->GetDevice()->ToString()); + src_device_.ToString(), tensor->GetDevice().ToString()); } } diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 35a73a23..67197747 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -8,7 +8,6 @@ #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" -#include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/process_group.h" @@ -26,7 +25,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) { for (auto ¶m : module->Parameters()) { auto device = param->GetDevice(); - CHECK_EQ(device->Index(), thread_rank) << "All parameters must be on the same device as the module"; + CHECK_EQ(device.index(), thread_rank) << "All parameters must be on the same device as the module"; if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) { auto hook = std::make_unique( function::ReduceOpType::kAvg, ddp_pg_); @@ -34,7 +33,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod } } for (auto &buffer : module->Buffers()) { - CHECK_EQ(buffer->GetDevice()->Index(), thread_rank) << "All buffers must be on the same device as the module"; + CHECK_EQ(buffer->GetDevice().index(), thread_rank) << "All buffers must be on the same device as the module"; } modules_[kModuleName] = std::move(module); diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 56984ca0..75a21f63 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -28,7 +28,7 @@ inline size_t PadTo(size_t value, size_t alignment) { return remainder == 0 ? 
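In data_parallel.cc the worker lambda swaps the raw device->SetDevice() call for a scoped core::DeviceGuard, so the device switch only lasts for the replica's forward pass. A minimal sketch of that scoping:

    #include "infini_train/include/core/device_guard.h"
    #include "infini_train/include/device.h"

    namespace infini_train::nn::parallel {

    // Sketch only: `device` is current inside the scope; the previously current
    // device is restored when `guard` is destroyed at the closing brace.
    void RunReplicaOnDevice(Device device) {
        core::DeviceGuard guard(device);
        // ... run the replica's forward pass; every launch sees `device` as current ...
    }

    } // namespace infini_train::nn::parallel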
value : value + (alignment - remainder); } -std::shared_ptr AllocateFlatBuffer(size_t num_elements, DataType data_type, const Device *device) { +std::shared_ptr AllocateFlatBuffer(size_t num_elements, DataType data_type, Device device) { std::vector dims = {static_cast(num_elements)}; // TODO(zbl): replace with united allocation when memory pool is available return std::make_shared(dims, data_type, device); @@ -103,8 +103,7 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vectorGetGroupRank(dynamic_cast(param->GetDevice())->rank().thread_rank()); + rank_in_collective_pg_ = collective_pg_->GetGroupRank(param->GetDevice().Rank().thread_rank()); } param_buffer_shard_list_.resize(buckets_.size()); diff --git a/infini_train/src/nn/parallel/ddp/reducer.cc b/infini_train/src/nn/parallel/ddp/reducer.cc index 905362e2..b7dc0fbb 100644 --- a/infini_train/src/nn/parallel/ddp/reducer.cc +++ b/infini_train/src/nn/parallel/ddp/reducer.cc @@ -5,69 +5,44 @@ #include #include -#ifdef USE_CUDA -#include -#endif - #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::parallel { namespace { -void CopyGradToBucket(const std::shared_ptr &grad, const std::shared_ptr &flat, size_t dst_elem_offset, - void *stream = nullptr) { +void CopyGradToBucket(const std::shared_ptr &grad, const std::shared_ptr &flat, + size_t dst_elem_offset) { CHECK(grad && flat); const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype()); const size_t bytes = grad->NumElements() * element_size_in_bytes; char *dst = static_cast(flat->DataPtr()) + dst_elem_offset * element_size_in_bytes; const void *src = grad->DataPtr(); - const auto dev_type = grad->GetDevice().type(); - if (dev_type == Device::DeviceType::kCPU) { - std::memcpy(dst, src, bytes); - return; - } -#ifdef USE_CUDA - if (dev_type == Device::DeviceType::kCUDA) { - auto *cuda_dev = dynamic_cast(flat->GetDevice()); - CHECK(cuda_dev); - cuda_dev->SetDevice(); - auto comm_stream = stream ? reinterpret_cast(stream) : cuda_dev->Stream(); - cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, comm_stream); - return; - } -#endif - LOG(FATAL) << "Unsupported device type in CopyGradToBucket"; + auto dev = grad->GetDevice(); + + core::DeviceGuard guard(dev); + auto impl = core::GetDeviceGuardImpl(dev.type()); + impl->MemcpyAsync(dst, src, bytes, core::MemcpyKind::kD2D, impl->GetStream(dev)); } -void CopyBucketToGrad(const std::shared_ptr &flat, const std::shared_ptr &grad, size_t src_elem_offset, - void *stream = nullptr) { +void CopyBucketToGrad(const std::shared_ptr &flat, const std::shared_ptr &grad, + size_t src_elem_offset) { CHECK(grad && flat); const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype()); const size_t bytes = grad->NumElements() * element_size_in_bytes; const char *src = static_cast(flat->DataPtr()) + src_elem_offset * element_size_in_bytes; void *dst = grad->DataPtr(); - const auto dev_type = grad->GetDevice().type(); - if (dev_type == Device::DeviceType::kCPU) { - std::memcpy(dst, src, bytes); - return; - } -#ifdef USE_CUDA - if (dev_type == Device::DeviceType::kCUDA) { - auto *cuda_dev = dynamic_cast(flat->GetDevice()); - CHECK(cuda_dev); - cuda_dev->SetDevice(); - auto comm_stream = stream ? 
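reducer.cc gets the same treatment: the CPU/CUDA branches in CopyGradToBucket and CopyBucketToGrad become one device-agnostic MemcpyAsync on the registered impl, and the optional comm-stream argument is dropped in favour of the compute stream returned by GetStream. A cleaned-up rendering of the new CopyGradToBucket body as it appears (reflowed) in this hunk; kDataTypeToSize and the Tensor accessors are the ones the file already uses:

    #include <cstddef>
    #include <memory>

    #include "infini_train/include/core/device_guard.h"
    #include "infini_train/include/tensor.h"

    namespace infini_train::nn::parallel {

    void CopyGradToBucket(const std::shared_ptr<Tensor> &grad, const std::shared_ptr<Tensor> &flat,
                          size_t dst_elem_offset) {
        const size_t element_size = kDataTypeToSize.at(grad->Dtype());
        const size_t bytes = grad->NumElements() * element_size;
        char *dst = static_cast<char *>(flat->DataPtr()) + dst_elem_offset * element_size;

        auto dev = grad->GetDevice();
        core::DeviceGuard guard(dev);
        auto impl = core::GetDeviceGuardImpl(dev.type());
        // kD2D: source and destination live on the same device (host-to-host for CPU).
        impl->MemcpyAsync(dst, grad->DataPtr(), bytes, core::MemcpyKind::kD2D, impl->GetStream(dev));
    }

    } // namespace infini_train::nn::parallel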
reinterpret_cast(stream) : cuda_dev->Stream(); - cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, comm_stream); - return; - } -#endif - LOG(FATAL) << "Unsupported device type in CopyBucketToGrad"; + auto dev = grad->GetDevice(); + + core::DeviceGuard guard(dev); + auto impl = core::GetDeviceGuardImpl(dev.type()); + impl->MemcpyAsync(dst, src, bytes, core::MemcpyKind::kD2D, impl->GetStream(dev)); } std::shared_ptr MakeGradView(const std::shared_ptr &contents, size_t offset_elems, @@ -204,7 +179,7 @@ void Reducer::BuildBuckets(const std::vector> &bucket_indice CHECK(!bucket_indices[bucket_idx].empty()); const auto &first_param = params_[bucket_indices[bucket_idx][0]]; bucket.dtype = first_param->Dtype(); - bucket.device_rank = first_param->GetDevice()->rank().GlobalRank(); + bucket.device_rank = first_param->GetDevice().Rank().GlobalRank(); size_t total_elems = 0; diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index 34d2272a..c0369cde 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -17,9 +17,9 @@ constexpr char kModuleName[] = "module"; thread_local int pp_rank = 0; -void PipelineParallel::BuildPipelineStage(const std::vector> &recv_shape, int device_id, +void PipelineParallel::BuildPipelineStage(const std::vector> &recv_shape, Device device, std::vector> &&chunks) { - pipeline_stage_ = std::make_shared(rank_, num_stages_, recv_shape, device_id, std::move(chunks)); + pipeline_stage_ = std::make_shared(rank_, num_stages_, recv_shape, device, std::move(chunks)); } void PipelineParallel::SetupSchedule(int num_micro_batches) { @@ -77,7 +77,7 @@ StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int rank } PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, - const std::vector> &recv_shape, int pp_rank, int device_id, + const std::vector> &recv_shape, int pp_rank, Device device, int chunk_size) : num_stages_(num_stages), rank_(pp_rank) { modules_[kModuleName] = std::move(module); @@ -98,7 +98,7 @@ PipelineParallel::PipelineParallel(const std::shared_ptr module, int num chunks.push_back(std::make_shared(std::move(chunk_parts))); } - BuildPipelineStage(recv_shape, device_id, std::move(chunks)); + BuildPipelineStage(recv_shape, device, std::move(chunks)); SetupSchedule(num_micro_batches); } diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index afcdaac2..d17e3d9a 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -60,7 +60,7 @@ std::vector> ISend::Forward(const std::vectorGetDevice(); auto pp_group - = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank())); + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_.Rank().GlobalRank())); pp_group->Send(input_tensors, peer_rank_, false); @@ -76,7 +76,7 @@ std::vector> ISend::Backward(const std::vectorGet(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank())); + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_.Rank().GlobalRank())); pp_group->Recv(recv_tensors, peer_rank_, false); @@ -84,9 +84,8 @@ std::vector> ISend::Backward(const std::vector> IRecv::Forward(const std::vector> &recv_tensors) { - CHECK_NOTNULL(src_device_); auto pp_group - = 
ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().GlobalRank())); + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_.Rank().GlobalRank())); pp_group->Recv(recv_tensors, peer_rank_, false); return recv_tensors; @@ -102,7 +101,7 @@ void IRecv::SetupContext(const std::vector> &input_tenso std::vector> IRecv::Backward(const std::vector> &grad_outputs) { auto pp_group - = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().GlobalRank())); + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_.Rank().GlobalRank())); pp_group->Send(grad_outputs, peer_rank_, false); diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index b4f5eef0..cb1e07b8 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -18,11 +18,13 @@ #ifdef USE_CUDA #include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train { @@ -130,8 +132,8 @@ void ProcessGroupNCCL::InitSingleProcess(const std::vector &ranks) { for (int i = 0; i < ranks.size(); ++i) { auto device = Device(Device::DeviceType::kCUDA, ranks[i]); devices_.push_back(device); - device_comm_map_[device] = comms_[i]; - global_group_rank_map_[device->rank().GlobalRank()] = i; + device_comm_map_[device.index()] = comms_[i]; + global_group_rank_map_[device.Rank().GlobalRank()] = i; } } @@ -166,9 +168,9 @@ void ProcessGroupNCCL::InitMultiProcess(const std::vector &ranks) { comms_.push_back(comm); auto device = Device(Device::DeviceType::kCUDA, i); - global_group_rank_map_[device->rank().GlobalRank()] = group_rank; + global_group_rank_map_[device.Rank().GlobalRank()] = group_rank; devices_.push_back(device); - device_comm_map_[device] = comm; + device_comm_map_[device.index()] = comm; } } NCCL_CHECK(ncclGroupEnd()); @@ -179,24 +181,27 @@ void ProcessGroupNCCL::InitStreams() { comm_streams_.resize(device_size); for (int i = 0; i < device_size; ++i) { - devices_[i]->SetDevice(); + core::DeviceGuard guard(devices_[i]); + int low, high; CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&low, &high)); CUDA_CHECK(cudaStreamCreateWithPriority(&comm_streams_[i], cudaStreamNonBlocking, high)); - device_stream_map_[devices_[i]] = comm_streams_[i]; + device_stream_map_[devices_[i].index()] = comm_streams_[i]; } } std::shared_ptr ProcessGroupNCCL::AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op, bool async_op) const { void *buffer = tensor->DataPtr(); - const auto *device = dynamic_cast(tensor->GetDevice()); - device->SetDevice(); + auto device = tensor->GetDevice(); + core::DeviceGuard guard(device); - auto comm = device_comm_map_.at(device); + auto comm = device_comm_map_.at(device.index()); - cudaStream_t compute_stream = device->Stream(); - cudaStream_t comm_stream = device_stream_map_.at(device); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaStream_t comm_stream = device_stream_map_.at(device.index()); auto work = std::make_shared(device, comm); @@ -222,13 
+227,15 @@ std::shared_ptr ProcessGroupNCCL::AllReduce(const std::shared_ptr std::shared_ptr ProcessGroupNCCL::AllGather(const std::shared_ptr &output, const std::shared_ptr &input, bool async_op) const { - const auto *device = dynamic_cast(input->GetDevice()); - auto comm = device_comm_map_.at(device); + auto device = input->GetDevice(); + auto comm = device_comm_map_.at(device.index()); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); - cudaStream_t comm_stream = device_stream_map_.at(device); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaStream_t comm_stream = device_stream_map_.at(device.index()); auto work = std::make_shared(device, comm); @@ -254,13 +261,15 @@ std::shared_ptr ProcessGroupNCCL::AllGather(const std::shared_ptr std::shared_ptr ProcessGroupNCCL::ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, function::ReduceOpType reduce_op, bool async_op) const { - const auto *device = dynamic_cast(input->GetDevice()); - auto comm = device_comm_map_.at(device); + auto device = input->GetDevice(); + auto comm = device_comm_map_.at(device.index()); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); - cudaStream_t comm_stream = device_stream_map_.at(device); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaStream_t comm_stream = device_stream_map_.at(device.index()); auto work = std::make_shared(device, comm); @@ -286,13 +295,15 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter(const std::shared_ptr ProcessGroupNCCL::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); - const auto *device = dynamic_cast(tensors[0]->GetDevice()); - auto comm = device_comm_map_.at(device); + auto device = tensors[0]->GetDevice(); + auto comm = device_comm_map_.at(device.index()); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); - cudaStream_t comm_stream = device_stream_map_.at(device); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaStream_t comm_stream = device_stream_map_.at(device.index()); auto work = std::make_shared(device, comm); @@ -332,13 +343,15 @@ std::shared_ptr ProcessGroupNCCL::Send(std::vector std::shared_ptr ProcessGroupNCCL::Recv(std::vector> tensors, int src_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); - const auto *device = dynamic_cast(tensors[0]->GetDevice()); - auto comm = device_comm_map_.at(device); + auto device = tensors[0]->GetDevice(); + auto comm = device_comm_map_.at(device.index()); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); - cudaStream_t comm_stream = device_stream_map_.at(device); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaStream_t comm_stream = device_stream_map_.at(device.index()); auto work = std::make_shared(device, comm); @@ -390,8 +403,10 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te outputs.push_back(std::make_shared(input_tensor->Dims(), input_tensor->Dtype(), device)); } devices.push_back(device); - streams.push_back(dynamic_cast(device)->Stream()); - 
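Every collective in process_group.cc is migrated the same way: device_comm_map_ and device_stream_map_ are now keyed by device.index() rather than by a device pointer, the SetDevice() call becomes a scoped core::DeviceGuard, and the compute stream comes from the device-guard registry while the dedicated comm stream still comes from device_stream_map_. A condensed sketch of the preamble shared by AllReduce/AllGather/ReduceScatter/Send/Recv (member names as in the hunks; the CudaStream cast target is assumed, since the template arguments were stripped from the patch text):

    // Fragment, as it appears inside each collective (condensed):
    auto device = tensor->GetDevice();
    core::DeviceGuard guard(device);                               // replaces device->SetDevice()

    auto comm = device_comm_map_.at(device.index());               // maps keyed by device index now
    cudaStream_t compute_stream = dynamic_cast<core::CudaStream *>(
                                      core::GetDeviceGuardImpl(device.type())->GetStream(device))
                                      ->cuda_stream();
    cudaStream_t comm_stream = device_stream_map_.at(device.index());
    // ... order comm_stream after compute_stream, then issue the NCCL call on comm_stream as before ...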
comms.push_back(device_comm_map_.at(device)); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream()); + comms.push_back(device_comm_map_.at(device.index())); } int root = -1; @@ -405,7 +420,8 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < devices.size(); ++i) { - devices[i]->SetDevice(); + core::DeviceGuard guard(devices[i]); + for (size_t j = 0; j < input_tensors.size(); ++j) { const auto &input_tensor = input_tensors[j]; const auto dtype = input_tensor->Dtype(); @@ -436,8 +452,10 @@ ProcessGroupNCCL::ReduceAddCoalesced(const std::vectorGetDevice()); - streams.push_back(dynamic_cast(devices[i])->Stream()); - comms.push_back(device_comm_map_.at(devices[i])); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i])) + ->cuda_stream()); + comms.push_back(device_comm_map_.at(devices[i].index())); } int root = -1; @@ -451,7 +469,8 @@ ProcessGroupNCCL::ReduceAddCoalesced(const std::vectorSetDevice(); + core::DeviceGuard guard(devices[i]); + for (size_t j = 0; j < grads[i].size(); ++j) { const auto &grad = grads[i][j]; const auto dtype = grad->Dtype(); @@ -479,8 +498,10 @@ std::vector> ProcessGroupNCCL::Scatter(const std::shared src_rank = i; } outputs.push_back(std::make_shared(split_tensors[i]->Dims(), split_tensors[i]->Dtype(), devices[i])); - streams.push_back(dynamic_cast(devices[i])->Stream()); - comms.push_back(device_comm_map_.at(devices[i])); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i])) + ->cuda_stream()); + comms.push_back(device_comm_map_.at(devices[i].index())); } CHECK_NE(src_rank, -1) << "Source device not found in input devices"; @@ -490,7 +511,8 @@ std::vector> ProcessGroupNCCL::Scatter(const std::shared auto nccl_dtype = kNcclDtypeMap.at(dtype); for (size_t i = 0; i < devices.size(); ++i) { - devices[i]->SetDevice(); + core::DeviceGuard guard(devices[i]); + const auto dtype = tensor->Dtype(); auto nccl_dtype = kNcclDtypeMap.at(dtype); NCCL_CHECK(ncclSend(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), nccl_dtype, i, @@ -521,8 +543,10 @@ std::shared_ptr ProcessGroupNCCL::Gather(const std::vector(device)->Stream()); - comms.push_back(device_comm_map_.at(device)); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream()); + comms.push_back(device_comm_map_.at(device.index())); devices.push_back(device); total_dim += tensors[i]->Dims()[dim]; @@ -538,7 +562,8 @@ std::shared_ptr ProcessGroupNCCL::Gather(const std::vectorSetDevice(); + core::DeviceGuard guard(devices[i]); + auto &tensor = tensors[i]; size_t num_elements = tensor->NumElements(); void *send_ptr = tensor->DataPtr(); diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 611ef0ce..2542d590 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -34,8 +34,7 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso } auto device = tensor->GetDevice(); - auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); + auto tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); std::vector output_shape = tensor->Dims(); output_shape[0] *= world_size; @@ -54,8 +53,7 @@ std::shared_ptr 
GatherAlongLastDim(const std::shared_ptr &tensor } auto device = tensor->GetDevice(); - auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); + auto tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); std::vector output_shape = tensor->Dims(); output_shape[0] *= world_size; @@ -79,9 +77,8 @@ std::shared_ptr SplitAlongLastDim(const std::shared_ptr &tensor) } auto device = tensor->GetDevice(); - auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); - auto rank = tp_group->GetGroupRank(device->rank().GlobalRank()); + auto tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + auto rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); auto last_dim_size = tensor->Dims().back() / world_size; auto shards = tensor->Split(last_dim_size, -1); @@ -97,8 +94,7 @@ std::shared_ptr Reduce(const std::shared_ptr &tensor) { } auto device = tensor->GetDevice(); - auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); + auto tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); auto output = std::make_shared(*tensor); @@ -115,8 +111,7 @@ std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr } auto device = tensor->GetDevice(); - auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); + auto tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); auto output_shape = tensor->Dims(); CHECK_EQ(output_shape[0] % world_size, 0) << "First dimension of the tensor should be divisible by TP world size"; @@ -436,8 +431,8 @@ VocabParallelCrossEntropy::Forward(const std::vector> &i const ProcessGroup *tp_group = nullptr; int rank = 0; if (tp_size > 1) { - tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); - rank = tp_group->GetGroupRank(device->rank().GlobalRank()); + tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); } vocab_size_local_ = logits->Dims().back(); diff --git a/infini_train/src/nn/parallel/work.cc b/infini_train/src/nn/parallel/work.cc index 00ff18c6..57018258 100644 --- a/infini_train/src/nn/parallel/work.cc +++ b/infini_train/src/nn/parallel/work.cc @@ -5,7 +5,9 @@ #ifdef USE_CUDA #include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::parallel { #ifdef USE_NCCL @@ -31,7 +33,7 @@ WorkNccl::~WorkNccl() { bool WorkNccl::WaitBlocking(std::chrono::milliseconds timeout) { // Block wait on host - device_->SetDevice(); + core::DeviceGuard guard(device_); // If timeout is not set, then wait till it finishes if (timeout <= std::chrono::milliseconds::zero()) { @@ -68,8 +70,11 @@ bool WorkNccl::WaitBlocking(std::chrono::milliseconds timeout) { bool WorkNccl::WaitNonBlocking() { // Non-blocking wait on compute stream - device_->SetDevice(); - CUDA_CHECK(cudaStreamWaitEvent(dynamic_cast(device_)->Stream(), done_event_, 0)); + core::DeviceGuard 
guard(device_); + CUDA_CHECK(cudaStreamWaitEvent(dynamic_cast( + core::GetDeviceGuardImpl(device_.type())->GetStream(device_)) + ->cuda_stream(), + done_event_, 0)); return true; } diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index afaf3c77..8eacafa3 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -2,6 +2,7 @@ #include +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -25,7 +26,7 @@ void SGD::Step() { continue; } auto device = param->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(param->grad(), -learning_rate_, param); } @@ -61,7 +62,7 @@ void Adam::Step() { auto &v = v_[i]; auto device = param->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AdamAccumulateGrad"}); kernel.Call(grad, param, m, v, learning_rate_, beta1_, beta2_, eps_, t_); } diff --git a/infini_train/src/profiler.cc b/infini_train/src/profiler.cc index 6464c24c..cbcb470a 100644 --- a/infini_train/src/profiler.cc +++ b/infini_train/src/profiler.cc @@ -14,7 +14,9 @@ #ifdef USE_CUDA #include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train { namespace { @@ -39,24 +41,18 @@ Profiler &Profiler::Instance() { } int GetRank(Device::DeviceType device) { - if (device == Device::DeviceType::kCPU) { - return 0; - } - - // Assume single-node setting, rank == device_id - int device_id = 0; -#ifdef USE_CUDA - CUDA_CHECK(cudaGetDevice(&device_id)); -#endif - return device_id; + auto impl = core::GetDeviceGuardImpl(device); + return impl->GetDevice().index(); } #ifdef USE_CUDA cudaStream_t GetCudaStream() { int device_id = GetRank(Device::DeviceType::kCUDA); // TODO(zbl): support multi-stream on single device - return dynamic_cast(Device(Device::DeviceType::kCUDA, static_cast(device_id))) - ->Stream(); + auto device = Device(Device::DeviceType::kCUDA, static_cast(device_id)); + return dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); } #endif diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 0d961fe5..774e36ad 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -1,5 +1,4 @@ #include "infini_train/include/tensor.h" -#include "infini_train/include/datatype.h" #include #include @@ -7,16 +6,9 @@ #include #include -#ifdef USE_CUDA -#include -#endif - #include "Eigen/Dense" #include "glog/logging.h" -#ifdef USE_CUDA -#include "infini_train/include/common/cuda/common_cuda.h" -#endif #include "infini_train/include/autograd/accumulate.h" #include "infini_train/include/autograd/elementwise.h" #include "infini_train/include/autograd/function.h" @@ -26,49 +18,23 @@ #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/transform.h" +#include "infini_train/include/core/device_guard.h" +#include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/init.h" namespace infini_train { 
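Illustrative sketch (not part of the patch series): the hunks above and the tensor.cc changes that follow all converge on one calling convention, where a scoped core::DeviceGuard replaces the old device->SetDevice() call and the backend stream comes from the registered DeviceGuardImpl rather than from CudaDevice::Stream(). The helper below shows that pattern in isolation; CurrentCudaStream is a hypothetical name, and the cast assumes CudaStream sits in core::cuda alongside the other CUDA core types.

#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/device.h"
#include "infini_train/src/core/cuda/cuda_stream.h"

namespace {

// Hypothetical helper: resolve the compute stream for `device` as a raw
// cudaStream_t, using the same expression the NCCL collectives now inline.
cudaStream_t CurrentCudaStream(infini_train::Device device) {
    using namespace infini_train;
    // RAII: switches to `device`, restores the previously active device on scope exit,
    // which is why the manual SetDevice/restore bookkeeping disappears from process_group.cc.
    core::DeviceGuard guard(device);
    auto *impl = core::GetDeviceGuardImpl(device.type());
    return dynamic_cast<core::cuda::CudaStream *>(impl->GetStream(device))->cuda_stream();
}

} // namespace
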
TensorBuffer::TensorBuffer(Device device, size_t size) : device_(device), size_(size) { - CHECK_NOTNULL(device); - switch (device.type()) { - case Device::DeviceType::kCPU: - data_ = malloc(size); - break; -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); - // TODO(dcj): Maybe pin memory later. - device->SetDevice(); - const auto *cuda_device = dynamic_cast(device); - CUDA_CHECK(cudaMallocAsync(&data_, size, cuda_device->Stream())); - CUDA_CHECK(cudaSetDevice(current_device)); - break; - } -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device_.type()); - break; - } + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + impl->MallocAsync(&data_, size, impl->GetStream(device)); } TensorBuffer::~TensorBuffer() { - switch (device_.type()) { - case Device::DeviceType::kCPU: - free(data_); - break; -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: - CUDA_CHECK(cudaFreeAsync(data_, dynamic_cast(device_)->Stream())); - break; -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device_.type()); - break; - } + core::DeviceGuard guard(device_); + auto *impl = core::GetDeviceGuardImpl(device_.type()); + impl->FreeAsync(data_, impl->GetStream(device_)); } void *TensorBuffer::DataPtr() { return data_; } @@ -98,19 +64,12 @@ Tensor::Tensor(const float *data, const std::vector &dims, DataType dty CHECK(dtype == DataType::kFLOAT32); buffer_ = std::make_shared(device, kDataTypeToSize.at(dtype) * num_elements_); - switch (device.type()) { - case Device::DeviceType::kCPU: - memcpy(buffer_->DataPtr(), data, buffer_->Size()); - break; -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: - CUDA_CHECK(cudaMemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream())); - break; -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); - } + + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + impl->MemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), + device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); } void Tensor::SetData(const Tensor &tensor, size_t offset, bool preserve_data) { @@ -145,7 +104,7 @@ DataType Tensor::Dtype() const { return dtype_; } template void Tensor::Fill(T value) { auto device = GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); DataType dtype = Dtype(); @@ -188,7 +147,8 @@ Eigen::Map> Tensor::Eig } Tensor Tensor::To(Device device) { - if (device == buffer_->GetDevice()) { + const auto buffer_device = buffer_->GetDevice(); + if (device == buffer_device) { auto new_tensor = Tensor(*this, offset_, dims_); if (grad_) { new_tensor.grad_ = std::make_unique(*grad_.get(), grad_->offset_, grad_->dims_); @@ -197,39 +157,31 @@ Tensor Tensor::To(Device device) { } Tensor new_tensor; - switch (device.type()) { -#ifdef USE_CUDA - case Device::DeviceType::kCPU: { - // CUDA -> CPU - GetDevice()->SetDevice(); + if (device.type() == Device::DeviceType::kCPU) { + // D2H new_tensor = Tensor(dims_, dtype_, Device()); - CUDA_CHECK(cudaMemcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), cudaMemcpyDeviceToHost)); - break; - } - case Device::DeviceType::kCUDA: { - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); + core::DeviceGuard guard(buffer_device); + auto impl = core::GetDeviceGuardImpl(buffer_device.type()); + impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H, + impl->GetStream(buffer_device)); + + } else if (buffer_device.type() == Device::DeviceType::kCPU) { new_tensor = Tensor(dims_, dtype_, device); - if (GetDevice().type() == Device::DeviceType::kCPU) { - device->SetDevice(); - // CPU -> CUDA - CUDA_CHECK(cudaMemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream())); - } else { - // p2p - // 1. CUDA -> CPU - // 2. CPU -> CUDA - Tensor cpu_tensor = To(Device()); - device->SetDevice(); - CUDA_CHECK(cudaMemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), - cudaMemcpyHostToDevice, dynamic_cast(device)->Stream())); - } - CUDA_CHECK(cudaSetDevice(current_device)); - break; - } -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); + // H2D + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, + impl->GetStream(device)); + } else { + new_tensor = Tensor(dims_, dtype_, device); + // P2P + // 1. D2H + Tensor cpu_tensor = To(Device()); + // 2. 
H2D + core::DeviceGuard guard(buffer_device); + auto *impl = core::GetDeviceGuardImpl(buffer_device.type()); + impl->MemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, + impl->GetStream(buffer_device)); } if (grad_) { @@ -251,7 +203,7 @@ Tensor Tensor::To(DataType dtype) { } auto device = GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); auto kernel = Dispatcher::Instance().GetKernel({device.type(), "Cast"}); auto new_tensor = *kernel.Call>(shared_from_this(), dtype); @@ -275,68 +227,30 @@ void Tensor::CopyFrom(const Tensor &src) { const Device dst_dev = GetDevice(); const Device src_dev = src.GetDevice(); - switch (dst_dev.type()) { - case Device::DeviceType::kCPU: { - switch (src_dev.type()) { - case Device::DeviceType::kCPU: { - std::memcpy(DataPtr(), src.DataPtr(), nbytes); - break; - } -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - // CUDA -> CPU - CUDA_CHECK(cudaMemcpy(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToHost)); - break; - } -#endif - default: - LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev.type()); - } - break; - } - -#ifdef USE_CUDA - case Device::DeviceType::kCUDA: { - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); - dst_dev->SetDevice(); - - const auto *dst_cuda = dynamic_cast(dst_dev); - switch (src_dev.type()) { - case Device::DeviceType::kCPU: { - // CPU -> CUDA - CUDA_CHECK(cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyHostToDevice, dst_cuda->Stream())); - break; - } - case Device::DeviceType::kCUDA: { - const auto *src_cuda = dynamic_cast(src_dev); - if (src_cuda.index() == dst_cuda.index()) { - CUDA_CHECK( - cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToDevice, dst_cuda->Stream())); - } else { - int canAccessPeer = 0; - CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, dst_cuda.index(), src_cuda.index())); - if (canAccessPeer) { - CUDA_CHECK(cudaMemcpyPeerAsync(DataPtr(), dst_cuda.index(), src.DataPtr(), src_cuda.index(), nbytes, - dst_cuda->Stream())); - } else { - LOG(FATAL) << "Check accessibility between Device " << src_cuda.index() << " and Device " - << dst_cuda.index(); - } - } - break; - } - default: - LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev.type()); - } - - CUDA_CHECK(cudaSetDevice(current_device)); - break; - } -#endif - - default: - LOG(FATAL) << "Unsupported dst device type: " << static_cast(dst_dev.type()); + if (dst_dev == src_dev) { + core::DeviceGuard guard(dst_dev); + auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev)); + } else if (dst_dev.type() == Device::DeviceType::kCPU) { + // D2H + core::DeviceGuard guard(src_dev); + auto *impl = core::GetDeviceGuardImpl(src_dev.type()); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev)); + } else if (src_dev.type() == Device::DeviceType::kCPU) { + // H2D + core::DeviceGuard guard(dst_dev); + auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); + } else { + // TODO(dcj): maybe support p2p api later + // P2P + // 1. D2H + Tensor cpu_tensor(dims_, dtype_, Device()); + cpu_tensor.CopyFrom(src); + // 2. 
H2D + core::DeviceGuard guard(dst_dev); + auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); + impl->MemcpyAsync(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); } } @@ -771,23 +685,14 @@ void Tensor::SaveAsNpy(const std::string &path) const { const size_t num_bytes = num_elements * sizeof(float); // Prepare host buffer - std::vector host_buffer(num_elements); + auto impl = core::GetDeviceGuardImpl(GetDevice().type()); - if (GetDevice().type() == Device::DeviceType::kCPU) { - // If on CPU, direct copy - std::memcpy(host_buffer.data(), DataPtr(), num_bytes); - } -#ifdef USE_CUDA - else if (GetDevice().type() == Device::DeviceType::kCUDA) { - // If on CUDA, copy back to host - cudaDeviceSynchronize(); - cudaError_t err = cudaMemcpy(host_buffer.data(), DataPtr(), num_bytes, cudaMemcpyDeviceToHost); - CHECK_EQ(err, cudaSuccess) << "cudaMemcpy failed: " << cudaGetErrorString(err); - } -#endif - else { - LOG(FATAL) << "Unsupported device type for SaveAsNpy."; - } + impl->SynchronizeDevice(GetDevice()); + + Tensor cpu_tensor(dims_, dtype_, Device()); + cpu_tensor.CopyFrom(*this); + + impl->SynchronizeDevice(GetDevice()); // Write .npy file std::ofstream file(path, std::ios::binary); @@ -829,7 +734,7 @@ void Tensor::SaveAsNpy(const std::string &path) const { file.write(header.c_str(), header.size()); // Write data - file.write(reinterpret_cast(host_buffer.data()), num_bytes); + file.write(reinterpret_cast(cpu_tensor.DataPtr()), num_bytes); file.close(); } @@ -892,21 +797,16 @@ void Tensor::Print(std::ostream &os) const { const size_t num_elements = NumElements(); const size_t num_bytes = num_elements * sizeof(float); - std::vector host_buffer(num_elements); + auto impl = core::GetDeviceGuardImpl(GetDevice().type()); - if (GetDevice().type() == Device::DeviceType::kCPU) { - std::memcpy(host_buffer.data(), DataPtr(), num_bytes); - } -#ifdef USE_CUDA - else if (GetDevice().type() == Device::DeviceType::kCUDA) { - cudaDeviceSynchronize(); - cudaError_t err = cudaMemcpy(host_buffer.data(), DataPtr(), num_bytes, cudaMemcpyDeviceToHost); - CHECK_EQ(err, cudaSuccess) << "cudaMemcpy failed: " << cudaGetErrorString(err); - } -#endif - else { - LOG(FATAL) << "Unsupported device type for Print."; - } + impl->SynchronizeDevice(GetDevice()); + + Tensor cpu_tensor(dims_, dtype_, Device()); + cpu_tensor.CopyFrom(*this); + + impl->SynchronizeDevice(GetDevice()); + + const float *buffer = static_cast(cpu_tensor.DataPtr()); const PrintOptions &opts = PrintOptions::Get(); const int64_t precision = opts.precision; @@ -917,7 +817,8 @@ void Tensor::Print(std::ostream &os) const { bool use_sci = opts.sci_mode.value_or(false); if (!opts.sci_mode.has_value()) { - for (float v : host_buffer) { + for (int idx = 0; idx < NumElements(); ++idx) { + const auto v = buffer[idx]; float abs_v = std::fabs(v); if ((abs_v > 0.0f && abs_v < 1e-4f) || abs_v >= 1e+4f) { use_sci = true; @@ -940,7 +841,7 @@ void Tensor::Print(std::ostream &os) const { std::vector str_vals(num_elements); size_t max_width = 0; for (size_t i = 0; i < num_elements; ++i) { - str_vals[i] = format_float(host_buffer[i]); + str_vals[i] = format_float(buffer[i]); max_width = std::max(max_width, str_vals[i].length()); } diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index 3391b9a8..f1825f14 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -1,6 +1,5 @@ #include 
"infini_train/include/utils/precision_checker.h" -#include #include #include #include @@ -12,7 +11,6 @@ #include #include #include -#include #include "infini_train/include/autograd/function.h" #include "infini_train/include/nn/modules/module.h" @@ -295,9 +293,8 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string // Copy tensor to CPU if it's on GPU std::shared_ptr cpu_tensor; - if (tensor->GetDevice()->Type() == DeviceType::kCUDA) { - auto cpu_device = DeviceManager::Instance()->GetDevice(DeviceType::kCPU); - cpu_tensor = std::make_shared(tensor->To(cpu_device)); + if (tensor->GetDevice().IsCUDA()) { + cpu_tensor = std::make_shared(tensor->To(Device())); } else { cpu_tensor = tensor; } From 7a43321ada8e6e9d433d6bb0bbad5161a4e37b32 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 9 Feb 2026 08:09:48 +0000 Subject: [PATCH 6/7] refactor: clean up device runtime interfaces and initialization semantics - Split internal implementation headers into a separate include group - Drop redundant explicit default initialization for Device - Add `impl` suffix to CUDA guard implementation files - Unify Arange initialization via DeviceGuardImpl --- infini_train/include/autograd/comm.h | 8 +- infini_train/include/core/device_guard.h | 3 +- .../include/nn/parallel/data_parallel.h | 4 +- .../include/nn/parallel/pp/pipeline_stage.h | 2 +- infini_train/include/nn/parallel/work.h | 2 +- infini_train/include/tensor.h | 2 +- .../cpu/{cpu_guard.cc => cpu_guard_impl.cc} | 15 +++- .../cpu/{cpu_guard.h => cpu_guard_impl.h} | 0 .../{cuda_guard.cc => cuda_guard_impl.cc} | 2 +- .../cuda/{cuda_guard.h => cuda_guard_impl.h} | 7 +- infini_train/src/core/device_guard.cc | 6 +- .../src/kernels/cuda/accumulate_grad.cu | 1 + infini_train/src/kernels/cuda/cast.cu | 1 + infini_train/src/kernels/cuda/concat.cu | 1 + .../src/kernels/cuda/cross_entropy.cu | 1 + infini_train/src/kernels/cuda/elementwise.cu | 1 + infini_train/src/kernels/cuda/embedding.cu | 1 + infini_train/src/kernels/cuda/fill.cu | 1 + infini_train/src/kernels/cuda/gather.cu | 1 + infini_train/src/kernels/cuda/layernorm.cu | 1 + infini_train/src/kernels/cuda/linear.cu | 1 + infini_train/src/kernels/cuda/outer.cu | 1 + infini_train/src/kernels/cuda/reduction.cu | 1 + infini_train/src/kernels/cuda/slice.cu | 1 + infini_train/src/kernels/cuda/softmax.cu | 1 + infini_train/src/kernels/cuda/split.cu | 1 + infini_train/src/kernels/cuda/stack.cu | 1 + infini_train/src/kernels/cuda/transform.cu | 1 + .../cuda/vocab_parallel_cross_entropy.cu | 1 + infini_train/src/nn/init.cc | 83 +++++++------------ infini_train/src/nn/parallel/pp/send_recv.cc | 8 +- infini_train/src/nn/parallel/process_group.cc | 1 + infini_train/src/nn/parallel/work.cc | 1 + infini_train/src/profiler.cc | 1 + scripts/test_config.json | 3 +- 35 files changed, 87 insertions(+), 79 deletions(-) rename infini_train/src/core/cpu/{cpu_guard.cc => cpu_guard_impl.cc} (53%) rename infini_train/src/core/cpu/{cpu_guard.h => cpu_guard_impl.h} (100%) rename infini_train/src/core/cuda/{cuda_guard.cc => cuda_guard_impl.cc} (98%) rename infini_train/src/core/cuda/{cuda_guard.h => cuda_guard_impl.h} (93%) diff --git a/infini_train/include/autograd/comm.h b/infini_train/include/autograd/comm.h index e74c821d..ec3cfe4a 100644 --- a/infini_train/include/autograd/comm.h +++ b/infini_train/include/autograd/comm.h @@ -33,7 +33,7 @@ class Scatter : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; std::vector target_gpus_; - Device input_device_ 
= Device(); + Device input_device_; int64_t dim_ = 0; }; @@ -52,7 +52,7 @@ class Gather : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - Device target_device_ = Device(); + Device target_device_; std::vector input_gpus_; int64_t dim_ = 0; bool unsqueezed_scalar_ = false; @@ -76,7 +76,7 @@ class Broadcast : public autograd::Function { const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; std::vector target_gpus_; int64_t num_inputs_ = 0; - Device input_device_ = Device(); + Device input_device_; }; class ReduceAddCoalesced : public autograd::Function { @@ -95,7 +95,7 @@ class ReduceAddCoalesced : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - Device destination_ = Device(); + Device destination_; std::vector target_gpus_; int64_t num_inputs_ = 0; }; diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h index 1ff262d0..de43b376 100644 --- a/infini_train/include/core/device_guard.h +++ b/infini_train/include/core/device_guard.h @@ -11,6 +11,7 @@ namespace infini_train::core { class Stream; class BlasHandle; +// Note(dcj): In the CPU backend, kD2D corresponds to a regular memcpy. enum class MemcpyKind : int8_t { kH2D = 0, kD2H = 1, @@ -161,7 +162,7 @@ class DeviceGuardImplRegistry { DeviceGuardImpl *Get(Device::DeviceType type) const; private: - DeviceGuardImplRegistry(); + DeviceGuardImplRegistry() = default; DeviceGuardImplRegistry(const DeviceGuardImplRegistry &) = delete; DeviceGuardImplRegistry &operator=(const DeviceGuardImplRegistry &) = delete; diff --git a/infini_train/include/nn/parallel/data_parallel.h b/infini_train/include/nn/parallel/data_parallel.h index f794d501..a1e6a57e 100644 --- a/infini_train/include/nn/parallel/data_parallel.h +++ b/infini_train/include/nn/parallel/data_parallel.h @@ -20,7 +20,7 @@ class DataParallel : public Module { private: int dim_ = 0; std::vector devices_; - Device output_device_ = Device(); - Device src_device_ = Device(); + Device output_device_; + Device src_device_; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index d1a21605..8dc0b8d3 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -42,7 +42,7 @@ class PipelineStage { int num_stages_ = -1; int prev_rank_ = -1; int next_rank_ = -1; - Device device_ = Device(); + Device device_; std::vector> chunks_; std::vector> recv_shape_; }; diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h index 8cc60f78..c6be5127 100644 --- a/infini_train/include/nn/parallel/work.h +++ b/infini_train/include/nn/parallel/work.h @@ -58,7 +58,7 @@ class WorkNccl final : public Work { void SetException(std::exception_ptr e); private: - Device device_ = Device(); + Device device_; cudaEvent_t ready_event_; cudaEvent_t done_event_; ncclComm_t comm_; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 95665b43..a40d0987 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -48,7 +48,7 @@ class TensorBuffer { size_t Size() const; private: - Device device_ = Device(); + Device device_; size_t size_ = 0; void *data_ = nullptr; }; diff --git a/infini_train/src/core/cpu/cpu_guard.cc b/infini_train/src/core/cpu/cpu_guard_impl.cc similarity index 53% rename from 
infini_train/src/core/cpu/cpu_guard.cc rename to infini_train/src/core/cpu/cpu_guard_impl.cc index 263081b8..9e3d9ec5 100644 --- a/infini_train/src/core/cpu/cpu_guard.cc +++ b/infini_train/src/core/cpu/cpu_guard_impl.cc @@ -1,8 +1,12 @@ -#include "infini_train/src/core/cpu/cpu_guard.h" +#include "infini_train/src/core/cpu/cpu_guard_impl.h" #include #include +#include "glog/logging.h" + +#include "infini_train/include/core/device_guard.h" + namespace infini_train::core::cpu { CpuGuardImpl::CpuGuardImpl() {} @@ -15,6 +19,13 @@ void CpuGuardImpl::Malloc(void **dev_ptr, size_t size) { *dev_ptr = std::malloc( void CpuGuardImpl::Free(void *dev_ptr) { std::free(dev_ptr); } -void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { std::memcpy(dst, src, count); } +void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { + CHECK(kind == MemcpyKind::kD2D) << "CpuGuardImpl::Memcpy only supports kD2D (host-to-host) memcpy, " + << "but got MemcpyKind=" << static_cast(kind); + + std::memcpy(dst, src, count); +} + +INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCPU, CpuGuardImpl) } // namespace infini_train::core::cpu diff --git a/infini_train/src/core/cpu/cpu_guard.h b/infini_train/src/core/cpu/cpu_guard_impl.h similarity index 100% rename from infini_train/src/core/cpu/cpu_guard.h rename to infini_train/src/core/cpu/cpu_guard_impl.h diff --git a/infini_train/src/core/cuda/cuda_guard.cc b/infini_train/src/core/cuda/cuda_guard_impl.cc similarity index 98% rename from infini_train/src/core/cuda/cuda_guard.cc rename to infini_train/src/core/cuda/cuda_guard_impl.cc index be364ff2..0a4a58ef 100644 --- a/infini_train/src/core/cuda/cuda_guard.cc +++ b/infini_train/src/core/cuda/cuda_guard_impl.cc @@ -1,4 +1,4 @@ -#include "infini_train/src/core/cuda/cuda_guard.h" +#include "infini_train/src/core/cuda/cuda_guard_impl.h" #include #include diff --git a/infini_train/src/core/cuda/cuda_guard.h b/infini_train/src/core/cuda/cuda_guard_impl.h similarity index 93% rename from infini_train/src/core/cuda/cuda_guard.h rename to infini_train/src/core/cuda/cuda_guard_impl.h index 400bb0da..af84570b 100644 --- a/infini_train/src/core/cuda/cuda_guard.h +++ b/infini_train/src/core/cuda/cuda_guard_impl.h @@ -2,11 +2,14 @@ #include -#include "infini_train/include/core/blas_handle.h" #include "infini_train/include/core/device_guard.h" -#include "infini_train/include/core/stream.h" #include "infini_train/include/device.h" +namespace infini_train::core { +class Stream; +class BlasHandle; +} // namespace infini_train::core + namespace infini_train::core::cuda { class CudaGuardImpl : public DeviceGuardImpl { diff --git a/infini_train/src/core/device_guard.cc b/infini_train/src/core/device_guard.cc index 714ab6d5..9b6d04ec 100644 --- a/infini_train/src/core/device_guard.cc +++ b/infini_train/src/core/device_guard.cc @@ -8,7 +8,6 @@ #include "infini_train/include/core/blas_handle.h" #include "infini_train/include/core/stream.h" -#include "infini_train/src/core/cpu/cpu_guard.h" namespace infini_train::core { @@ -71,6 +70,7 @@ std::pair DeviceGuardImpl::GetMemPoolPeakMB(Device device) const DeviceGuard::DeviceGuard(Device device) : impl_(GetDeviceGuardImpl(device.type())) { original_device_ = impl_->GetDevice(); impl_->SetDevice(device); + current_device_ = device; } void DeviceGuard::SetDevice(Device device) { @@ -88,10 +88,6 @@ Device DeviceGuard::original_device() const { return original_device_; } DeviceGuard::~DeviceGuard() { impl_->SetDevice(original_device_); 
} // DeviceGuardImplRegistry -DeviceGuardImplRegistry::DeviceGuardImplRegistry() { - Register(Device::DeviceType::kCPU, std::make_unique()); -} - DeviceGuardImplRegistry &DeviceGuardImplRegistry::Instance() { static DeviceGuardImplRegistry instance; return instance; diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index d18f6320..7b64c2a5 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -5,6 +5,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index e4698582..b048d026 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -7,6 +7,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index b0f239f0..1fa8face 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -10,6 +10,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/cross_entropy.cu b/infini_train/src/kernels/cuda/cross_entropy.cu index 333f8e2b..e21900f5 100644 --- a/infini_train/src/kernels/cuda/cross_entropy.cu +++ b/infini_train/src/kernels/cuda/cross_entropy.cu @@ -11,6 +11,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 2ebba200..1b9fe9eb 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -7,6 +7,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/embedding.cu b/infini_train/src/kernels/cuda/embedding.cu index ec025098..b65699f1 100644 --- a/infini_train/src/kernels/cuda/embedding.cu +++ b/infini_train/src/kernels/cuda/embedding.cu @@ -4,6 +4,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index a278a93a..8944e61a 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -5,6 +5,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include 
"infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 382e3d5a..d318465c 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -4,6 +4,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/layernorm.cu b/infini_train/src/kernels/cuda/layernorm.cu index 47e5654e..c899f8ff 100644 --- a/infini_train/src/kernels/cuda/layernorm.cu +++ b/infini_train/src/kernels/cuda/layernorm.cu @@ -6,6 +6,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index dbbc8cce..2f8b93d8 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -11,6 +11,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_blas_handle.h" #include "infini_train/src/core/cuda/cuda_stream.h" diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index d595024a..64303708 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -10,6 +10,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_blas_handle.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/reduction.cu b/infini_train/src/kernels/cuda/reduction.cu index ac5e6d20..c54704b2 100644 --- a/infini_train/src/kernels/cuda/reduction.cu +++ b/infini_train/src/kernels/cuda/reduction.cu @@ -6,6 +6,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 933943af..29a8f1ae 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -7,6 +7,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/softmax.cu b/infini_train/src/kernels/cuda/softmax.cu index 7c453ab2..943831e1 100644 --- a/infini_train/src/kernels/cuda/softmax.cu +++ b/infini_train/src/kernels/cuda/softmax.cu @@ -11,6 +11,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index ee887b7e..ec258976 100644 --- 
a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -6,6 +6,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index d544f11d..56067cb8 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -10,6 +10,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index d9284c46..62d316b1 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -9,6 +9,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu index e023cbbd..a626eb63 100644 --- a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu +++ b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu @@ -7,6 +7,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 9cd56f7a..27f473c2 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -19,7 +19,6 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/tensor.h" -#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::init { namespace { @@ -187,73 +186,49 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { return tensor; } -#define CASE(DATA_TYPE, TYPE) \ +#define ARANGE_CASE(DATA_TYPE, TYPE) \ case DATA_TYPE: { \ std::vector buffer(num_elements); \ std::iota(buffer.begin(), buffer.end(), static_cast(start)); \ - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE)); \ - break; \ - } -#define CUDA_CASE(DATA_TYPE, TYPE) \ - case DATA_TYPE: { \ - std::vector buffer(num_elements); \ - std::iota(buffer.begin(), buffer.end(), static_cast(start)); \ - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), cudaMemcpyHostToDevice, \ - dynamic_cast( \ - core::GetDeviceGuardImpl(device.type())->GetStream(device)) \ - ->cuda_stream()); \ + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind, stream); \ break; \ } std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Device device) { - int64_t num_elements = end - start; + const int64_t num_elements = end - start; auto tensor = std::make_shared(std::vector{num_elements}, dtype, device); + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + + const core::MemcpyKind kind = device.IsCPU() ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D; + core::Stream *stream = impl->GetStream(device); + + switch (dtype) { + ARANGE_CASE(DataType::kUINT8, uint8_t) + ARANGE_CASE(DataType::kINT8, int8_t) + ARANGE_CASE(DataType::kUINT16, uint16_t) + ARANGE_CASE(DataType::kINT16, int16_t) + ARANGE_CASE(DataType::kUINT32, uint32_t) + ARANGE_CASE(DataType::kINT32, int32_t) + ARANGE_CASE(DataType::kUINT64, uint64_t) + ARANGE_CASE(DataType::kINT64, int64_t) - if (device.IsCPU()) { - switch (dtype) { - CASE(DataType::kUINT8, uint8_t) - CASE(DataType::kINT8, int8_t) - CASE(DataType::kUINT16, uint16_t) - CASE(DataType::kINT16, int16_t) - CASE(DataType::kUINT32, uint32_t) - CASE(DataType::kINT32, int32_t) - CASE(DataType::kUINT64, uint64_t) - CASE(DataType::kINT64, int64_t) - // CASE(DataType::kBFLOAT16, bf16) - // CASE(DataType::kFLOAT16, fp16) - CASE(DataType::kFLOAT32, float) - CASE(DataType::kFLOAT64, double) - default: - LOG(FATAL) << "Unsupported data type: " << static_cast(dtype); - break; - } - } else { #ifdef USE_CUDA - switch (dtype) { - CUDA_CASE(DataType::kUINT8, uint8_t) - CUDA_CASE(DataType::kINT8, int8_t) - CUDA_CASE(DataType::kUINT16, uint16_t) - CUDA_CASE(DataType::kINT16, int16_t) - CUDA_CASE(DataType::kUINT32, uint32_t) - CUDA_CASE(DataType::kINT32, int32_t) - CUDA_CASE(DataType::kUINT64, uint64_t) - CUDA_CASE(DataType::kINT64, int64_t) - CUDA_CASE(DataType::kBFLOAT16, nv_bfloat16) - CUDA_CASE(DataType::kFLOAT16, half) - CUDA_CASE(DataType::kFLOAT32, float) - CUDA_CASE(DataType::kFLOAT64, double) - default: - LOG(FATAL) << "Unsupported data type: " << static_cast(dtype); - break; - } -#else - LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); + ARANGE_CASE(DataType::kBFLOAT16, nv_bfloat16) + ARANGE_CASE(DataType::kFLOAT16, half) #endif + + ARANGE_CASE(DataType::kFLOAT32, float) + ARANGE_CASE(DataType::kFLOAT64, double) + + default: + LOG(FATAL) << "Unsupported data type: " << static_cast(dtype); + break; } + return tensor; } -#undef CASE -#undef CUDA_CASE +#undef ARANGE_CASE } // namespace infini_train::nn::init diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index d17e3d9a..18e335db 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -27,8 +27,8 @@ class ISend : public autograd::Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: - Device target_device_ = Device(); - Device input_device_ = Device(); + Device target_device_; + Device input_device_; int cur_rank_ = -1; int peer_rank_ = -1; const std::vector> &shapes_; @@ -49,8 +49,8 @@ class IRecv : public autograd::Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: - Device src_device_ = Device(); - Device cur_device_ = Device(); + Device src_device_; + Device cur_device_; int cur_rank_ = -1; int peer_rank_ = -1; }; diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index cb1e07b8..d9508592 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -24,6 +24,7 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train { diff --git a/infini_train/src/nn/parallel/work.cc b/infini_train/src/nn/parallel/work.cc index 57018258..8c57070b 100644 --- 
a/infini_train/src/nn/parallel/work.cc +++ b/infini_train/src/nn/parallel/work.cc @@ -7,6 +7,7 @@ #endif #include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::parallel { diff --git a/infini_train/src/profiler.cc b/infini_train/src/profiler.cc index cbcb470a..5235bd17 100644 --- a/infini_train/src/profiler.cc +++ b/infini_train/src/profiler.cc @@ -16,6 +16,7 @@ #endif #include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" + #include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train { diff --git a/scripts/test_config.json b/scripts/test_config.json index 84f4fedd..5659b516 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -6,7 +6,8 @@ "LLAMA3_INPUT_BIN": "../../data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin", "LLAMA3_LLMC_FILEPATH": "../../data/llmc/llama3/llama3.2_1B_fp32.bin", "PROFILE_LOG_DIR": "./profile_logs", - "LOG_DIR": "logs" + "LOG_DIR": "./logs", + "COMPARE_LOG_DIR": "" }, "builds": [ { From 4f0fa84956958a70c294e97454ed0c758bb7facc Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 11 Feb 2026 03:31:49 +0000 Subject: [PATCH 7/7] refactor: enforce strict backend contracts for DeviceGuardImpl - Drop legacy hardware-specific branching - Convert DeviceGuardImpl base methods to fatal-only fallbacks - Explicitly implement supported CPU runtime behaviors - Validate CUDA device type and index bounds in CudaGuardImpl - Widen DeviceCount return type to prevent truncation --- example/gpt2/main.cc | 8 +- example/llama3/main.cc | 8 +- infini_train/include/core/blas_handle.h | 2 +- infini_train/include/core/device_guard.h | 17 ++++- infini_train/include/core/stream.h | 1 + infini_train/src/core/cpu/cpu_guard_impl.cc | 73 ++++++++++++++++++- infini_train/src/core/cpu/cpu_guard_impl.h | 41 +++++++++-- .../src/core/cuda/cuda_blas_handle.cc | 5 +- infini_train/src/core/cuda/cuda_blas_handle.h | 8 ++ infini_train/src/core/cuda/cuda_guard_impl.cc | 51 ++++++++++--- infini_train/src/core/cuda/cuda_guard_impl.h | 6 +- infini_train/src/core/cuda/cuda_stream.cc | 5 ++ infini_train/src/core/cuda/cuda_stream.h | 8 ++ infini_train/src/core/device_guard.cc | 50 ++++++------- infini_train/src/utils/precision_checker.cc | 7 +- 15 files changed, 218 insertions(+), 72 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 9d5d0313..3dfeadd3 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -285,9 +285,7 @@ void Train(const nn::parallel::Rank &rank) { const bool last_step = step == FLAGS_num_iteration; - if (device.IsCUDA()) { - impl->ResetMemPoolHighWatermarks(device); - } + impl->ResetMemPoolHighWatermarks(device); const auto iter_start = std::chrono::high_resolution_clock::now(); @@ -378,9 +376,7 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; - if (device.IsCUDA()) { - std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - } + std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 31354109..a7de81ff 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -261,9 +261,7 @@ void Train(const nn::parallel::Rank &rank) { const bool 
last_step = step == FLAGS_num_iteration; - if (device.IsCUDA()) { - impl->ResetMemPoolHighWatermarks(device); - } + impl->ResetMemPoolHighWatermarks(device); const auto iter_start = std::chrono::high_resolution_clock::now(); @@ -354,9 +352,7 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; - if (device.IsCUDA()) { - std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - } + std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", diff --git a/infini_train/include/core/blas_handle.h b/infini_train/include/core/blas_handle.h index 56b058ce..4b20e11c 100644 --- a/infini_train/include/core/blas_handle.h +++ b/infini_train/include/core/blas_handle.h @@ -4,7 +4,7 @@ namespace infini_train::core { class BlasHandle { public: - BlasHandle(){}; + BlasHandle() = default; virtual ~BlasHandle() = default; }; diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h index de43b376..36945ea1 100644 --- a/infini_train/include/core/device_guard.h +++ b/infini_train/include/core/device_guard.h @@ -19,6 +19,21 @@ enum class MemcpyKind : int8_t { kInvalid = -1, }; +inline const char *MemcpyKindToString(MemcpyKind k) { + switch (k) { + case MemcpyKind::kH2D: + return "kH2D"; + case MemcpyKind::kD2H: + return "kD2H"; + case MemcpyKind::kD2D: + return "kD2D"; + case MemcpyKind::kInvalid: + return "kInvalid"; + default: + return "Unknown"; + } +} + // // ---------------------------------------------------------------------------- // DeviceGuardImpl: Backend-specific device/stream/memory/BLAS implementation @@ -56,7 +71,7 @@ class DeviceGuardImpl { virtual void SetDevice(Device device) const; - virtual int8_t DeviceCount() const; + virtual int DeviceCount() const; virtual Device::DeviceType Type() const = 0; diff --git a/infini_train/include/core/stream.h b/infini_train/include/core/stream.h index 190298f6..db8aa2e0 100644 --- a/infini_train/include/core/stream.h +++ b/infini_train/include/core/stream.h @@ -4,6 +4,7 @@ namespace infini_train::core { class Stream { public: + Stream() = default; virtual ~Stream() = default; }; diff --git a/infini_train/src/core/cpu/cpu_guard_impl.cc b/infini_train/src/core/cpu/cpu_guard_impl.cc index 9e3d9ec5..45972258 100644 --- a/infini_train/src/core/cpu/cpu_guard_impl.cc +++ b/infini_train/src/core/cpu/cpu_guard_impl.cc @@ -2,6 +2,8 @@ #include #include +#include +#include #include "glog/logging.h" @@ -15,17 +17,82 @@ Device CpuGuardImpl::GetDevice() const { return Device(Device::DeviceType::kCPU, Device::DeviceType CpuGuardImpl::Type() const { return Device::DeviceType::kCPU; } +void CpuGuardImpl::SetDevice(Device device) const { + // No-op for CPU + CHECK(device.type() == Device::DeviceType::kCPU); + LOG(WARNING) << "CpuGuardImpl::SetDevice is not supported. " + "The call is ignored."; +} + +int CpuGuardImpl::DeviceCount() const { return 1; } + +Stream *CpuGuardImpl::GetStream(Device device) const { + CHECK(device.type() == Device::DeviceType::kCPU); + LOG(WARNING) << "CpuGuardImpl::GetStream is not supported. " + "Return nullptr."; + return nullptr; +} + +void CpuGuardImpl::SynchronizeDevice(Device device) const { + // No-op for CPU + CHECK(device.type() == Device::DeviceType::kCPU); + LOG(WARNING) << "CpuGuardImpl::SynchronizeDevice is not supported. 
" + "The call is ignored."; +} + +void CpuGuardImpl::SynchronizeStream(Stream *) const { + // No-op for CPU + LOG(WARNING) << "CpuGuardImpl::SynchronizeStream is not supported. " + "The call is ignored."; +} + +BlasHandle *CpuGuardImpl::GetBlasHandle(Device device) const { + CHECK(device.type() == Device::DeviceType::kCPU); + LOG(WARNING) << "CpuGuardImpl::GetBlasHandle is not supported. " + "Return nullptr."; + return nullptr; +} + void CpuGuardImpl::Malloc(void **dev_ptr, size_t size) { *dev_ptr = std::malloc(size); } void CpuGuardImpl::Free(void *dev_ptr) { std::free(dev_ptr); } -void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { - CHECK(kind == MemcpyKind::kD2D) << "CpuGuardImpl::Memcpy only supports kD2D (host-to-host) memcpy, " - << "but got MemcpyKind=" << static_cast(kind); +void CpuGuardImpl::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { + LOG(WARNING) << "CpuGuardImpl::MallocAsync is not supported. Falling back to blocking Malloc()"; + Malloc(dev_ptr, size); +} + +void CpuGuardImpl::FreeAsync(void *dev_ptr, Stream *stream) { + LOG(WARNING) << "CpuGuardImpl::FreeAsync is not supported. Falling back to blocking Free()"; + Free(dev_ptr); +} +void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { + CHECK(kind == MemcpyKind::kD2D) << std::format("CpuGuardImpl::Memcpy only supports kD2D (host-to-host) memcpy, " + "but got MemcpyKind={}", + MemcpyKindToString(kind)); std::memcpy(dst, src, count); } +void CpuGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { + LOG(WARNING) << "CpuGuardImpl::MemcpyAsync is not supported. Falling back to blocking Memcpy()"; + Memcpy(dst, src, count, kind); +} + +void CpuGuardImpl::ResetMemPoolHighWatermarks(Device device) const { + // No-op for CPU + CHECK(device.type() == Device::DeviceType::kCPU); + LOG(WARNING) << "CpuGuardImpl::ResetMemPoolHighWatermarks is not supported. " + "The call is ignored."; +} + +std::pair CpuGuardImpl::GetMemPoolPeakMB(Device device) const { + CHECK(device.type() == Device::DeviceType::kCPU); + LOG(WARNING) << "CpuGuardImpl::GetMemPoolPeakMB is not supported. 
" + "Return {0, 0}."; + return {0, 0}; +} + INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCPU, CpuGuardImpl) } // namespace infini_train::core::cpu diff --git a/infini_train/src/core/cpu/cpu_guard_impl.h b/infini_train/src/core/cpu/cpu_guard_impl.h index 3b6ac71f..83d0da19 100644 --- a/infini_train/src/core/cpu/cpu_guard_impl.h +++ b/infini_train/src/core/cpu/cpu_guard_impl.h @@ -1,22 +1,51 @@ #pragma once +#include + #include "infini_train/include/core/device_guard.h" namespace infini_train::core::cpu { -class CpuGuardImpl : public DeviceGuardImpl { +class CpuGuardImpl final : public DeviceGuardImpl { public: CpuGuardImpl(); - Device GetDevice() const; + // Device management + Device GetDevice() const override; + + void SetDevice(Device device) const override; // CPU: no-op + + int DeviceCount() const override; // CPU: 1 + + Device::DeviceType Type() const override; + + // Stream management (explicitly unsupported for now) + Stream *GetStream(Device device) const override; + + // Synchronization + void SynchronizeDevice(Device device) const override; // CPU: no-op + + void SynchronizeStream(Stream *stream) const override; // CPU: no-op + + // BLAS handle (explicitly unsupported for now) + BlasHandle *GetBlasHandle(Device device) const override; + + // Memory ops (async ops falls back to blocking explicitly in CPU impl) + void Malloc(void **dev_ptr, size_t size) override; + + void Free(void *dev_ptr) override; + + void MallocAsync(void **dev_ptr, size_t size, Stream *stream) override; + + void FreeAsync(void *dev_ptr, Stream *stream) override; - Device::DeviceType Type() const; + void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) override; - void Malloc(void **dev_ptr, size_t size); + void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) override; - void Free(void *dev_ptr); + void ResetMemPoolHighWatermarks(Device device) const override; // CPU: no-op - void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind); + std::pair GetMemPoolPeakMB(Device device) const override; // CPU: {0, 0} }; } // namespace infini_train::core::cpu diff --git a/infini_train/src/core/cuda/cuda_blas_handle.cc b/infini_train/src/core/cuda/cuda_blas_handle.cc index 36da1eab..d485133c 100644 --- a/infini_train/src/core/cuda/cuda_blas_handle.cc +++ b/infini_train/src/core/cuda/cuda_blas_handle.cc @@ -1,4 +1,3 @@ - #include "infini_train/src/core/cuda/cuda_blas_handle.h" #include "infini_train/include/common/cuda/common_cuda.h" @@ -12,6 +11,10 @@ CudaBlasHandle::CudaBlasHandle(Stream *stream) { CUBLAS_CHECK(cublasSetStream(cublas_handle_, dynamic_cast(stream)->cuda_stream())); } +CudaBlasHandle::~CudaBlasHandle() { + // Do nothing. +} + cublasHandle_t CudaBlasHandle::cublas_handle() const { return cublas_handle_; } } // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_blas_handle.h b/infini_train/src/core/cuda/cuda_blas_handle.h index 53678916..44569f28 100644 --- a/infini_train/src/core/cuda/cuda_blas_handle.h +++ b/infini_train/src/core/cuda/cuda_blas_handle.h @@ -14,6 +14,14 @@ class CudaBlasHandle : public BlasHandle { public: explicit CudaBlasHandle(Stream *stream); + // NOTE(dcj): + // The CudaBlasHandle are "leaked": they are created but never destroyed because the + // destruction of global variables could happen after the CUDA runtime has + // already been destroyed and thus invoking cudaStreamDestroy could lead to a + // crash. 
diff --git a/infini_train/src/core/cuda/cuda_blas_handle.h b/infini_train/src/core/cuda/cuda_blas_handle.h
index 53678916..44569f28 100644
--- a/infini_train/src/core/cuda/cuda_blas_handle.h
+++ b/infini_train/src/core/cuda/cuda_blas_handle.h
@@ -14,6 +14,14 @@ class CudaBlasHandle : public BlasHandle {
 public:
     explicit CudaBlasHandle(Stream *stream);
 
+    // NOTE(dcj):
+    // CudaBlasHandle objects are "leaked": they are created but never destroyed, because the
+    // destruction of global variables could happen after the CUDA runtime has
+    // already been destroyed, and invoking cublasDestroy at that point could lead to a
+    // crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
+    // the destruction.
+    ~CudaBlasHandle() override;
+
     cublasHandle_t cublas_handle() const;
 
 private:
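
The cuda_guard_impl.cc changes below add device/index validation and an async copy path. The expected driving sequence is roughly this sketch; the free-function wrapper is illustrative only, and the CUDA-side SynchronizeStream override is assumed (it is declared in the base interface but not visible in these hunks).

    // Sketch only: H2D copy on the per-device stream cached by CudaGuardImpl.
    void CopyHostToDeviceAsync(infini_train::Device device, void *dst, const void *src, size_t bytes) {
        using namespace infini_train::core;
        cuda::CudaGuardImpl impl;
        impl.SetDevice(device);
        Stream *stream = impl.GetStream(device);
        impl.MemcpyAsync(dst, src, bytes, MemcpyKind::kH2D, stream);
        impl.SynchronizeStream(stream); // or defer until the data is actually needed
    }
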
diff --git a/infini_train/src/core/cuda/cuda_guard_impl.cc b/infini_train/src/core/cuda/cuda_guard_impl.cc
index 0a4a58ef..f6c42a6d 100644
--- a/infini_train/src/core/cuda/cuda_guard_impl.cc
+++ b/infini_train/src/core/cuda/cuda_guard_impl.cc
@@ -1,7 +1,6 @@
 #include "infini_train/src/core/cuda/cuda_guard_impl.h"
 
 #include 
-#include 
 
 #include 
 #include 
@@ -22,9 +21,18 @@ static std::array<std::unique_ptr<CudaBlasHandle>, kMaxGpus> cuda_blas_handles;
 static std::array<std::once_flag, kMaxGpus> device_stream_flags;
 static std::array<std::once_flag, kMaxGpus> device_handle_flags;
+
+inline void CheckCudaDevice(Device device) {
+    CHECK(device.type() == Device::DeviceType::kCUDA) << std::format(
+        "CudaGuardImpl expects CUDA device, but got type={} index={}", static_cast<int>(device.type()), device.index());
+    const int idx = device.index();
+    CHECK(idx >= 0 && idx < kMaxGpus) << std::format("CUDA device index {} out of cache range [0, {}).", idx, kMaxGpus);
+}
 } // namespace
 
 void CudaGuardImpl::InitSingleStream(Device device) {
+    CheckCudaDevice(device);
+
     int current_device = -1;
     CUDA_CHECK(cudaGetDevice(&current_device));
     CUDA_CHECK(cudaSetDevice(device.index()));
@@ -35,6 +43,8 @@ void CudaGuardImpl::InitSingleStream(Device device) {
 }
 
 void CudaGuardImpl::InitSingleHandle(Device device) {
+    CheckCudaDevice(device);
+
     int current_device = -1;
     CUDA_CHECK(cudaGetDevice(&current_device));
     CUDA_CHECK(cudaSetDevice(device.index()));
@@ -55,9 +65,12 @@ Device CudaGuardImpl::GetDevice() const {
     return Device(Device::DeviceType::kCUDA, current_device);
 }
 
-void CudaGuardImpl::SetDevice(Device device) const { CUDA_CHECK(cudaSetDevice(device.index())); }
+void CudaGuardImpl::SetDevice(Device device) const {
+    CheckCudaDevice(device);
+    CUDA_CHECK(cudaSetDevice(device.index()));
+}
 
-int8_t CudaGuardImpl::DeviceCount() const {
+int CudaGuardImpl::DeviceCount() const {
     int device_count = 0;
     CUDA_DRIVER_CHECK(cuDeviceGetCount(&device_count));
     return device_count;
@@ -67,6 +80,10 @@ Device::DeviceType CudaGuardImpl::Type() const { return Device::DeviceType::kCUDA; }
 
 // stream
 Stream *CudaGuardImpl::GetStream(Device device) const {
+    CheckCudaDevice(device);
+    // FIXME(dcj): call_once is process-scoped and assumes single initialization.
+    // This can be problematic if the CUDA backend is initialized multiple
+    // times within the same process (e.g. in unit tests).
     std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device);
     return cuda_streams.at(device.index()).get();
 }
@@ -85,6 +102,7 @@ void CudaGuardImpl::SynchronizeDevice(Device device) const {
 
 // blas
 BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const {
+    CheckCudaDevice(device);
     std::call_once(device_handle_flags.at(device.index()), InitSingleHandle, device);
     return cuda_blas_handles.at(device.index()).get();
 }
@@ -110,24 +128,32 @@ void CudaGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) {
     } else if (kind == MemcpyKind::kD2D) {
         CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToDevice));
     } else {
-        LOG(FATAL) << "Invalid MemcpyKind";
+        LOG(FATAL) << std::format("CudaGuardImpl::Memcpy got invalid MemcpyKind={}", MemcpyKindToString(kind));
     }
 }
 
 void CudaGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) {
     cudaStream_t cuda_stream = dynamic_cast<CudaStream *>(stream)->cuda_stream();
-    if (kind == MemcpyKind::kH2D) {
+
+    switch (kind) {
+    case MemcpyKind::kH2D:
         CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyHostToDevice, cuda_stream));
-    } else if (kind == MemcpyKind::kD2H) {
+        break;
+    case MemcpyKind::kD2H:
         CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToHost, cuda_stream));
-    } else if (kind == MemcpyKind::kD2D) {
+        break;
+    case MemcpyKind::kD2D:
         CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, cuda_stream));
-    } else {
-        LOG(FATAL) << "Invalid MemcpyKind";
+        break;
+    default:
+        LOG(FATAL) << std::format("CudaGuardImpl::MemcpyAsync got invalid MemcpyKind={}", MemcpyKindToString(kind));
     }
 }
 
 void CudaGuardImpl::ResetMemPoolHighWatermarks(Device device) const {
+    int current_device = -1;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+    SetDevice(device);
     cudaMemPool_t pool;
     CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, device.index()));
@@ -136,9 +162,14 @@ void CudaGuardImpl::ResetMemPoolHighWatermarks(Device device) const {
     // High watermark can only be reset to zero; non-zero is illegal.
     CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &zero));
     CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &zero));
+
+    CUDA_CHECK(cudaSetDevice(current_device));
 }
 
 std::pair CudaGuardImpl::GetMemPoolPeakMB(Device device) const {
+    int current_device = -1;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+    SetDevice(device);
     cudaMemPool_t pool;
     CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, device.index()));
@@ -149,6 +180,8 @@ std::pair CudaGuardImpl::GetMemPoolPeakMB(Device device) const {
     cuuint64_t reserved = 0;
     CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &reserved));
 
+    CUDA_CHECK(cudaSetDevice(current_device));
+
     return std::make_pair(static_cast(used / kBytesPerMB), static_cast(reserved / kBytesPerMB));
 }
 
diff --git a/infini_train/src/core/cuda/cuda_guard_impl.h b/infini_train/src/core/cuda/cuda_guard_impl.h
index af84570b..7a64dbe4 100644
--- a/infini_train/src/core/cuda/cuda_guard_impl.h
+++ b/infini_train/src/core/cuda/cuda_guard_impl.h
@@ -1,7 +1,5 @@
 #pragma once
 
-#include 
-
 #include "infini_train/include/core/device_guard.h"
 #include "infini_train/include/device.h"
 
@@ -12,7 +10,7 @@ class BlasHandle;
 
 namespace infini_train::core::cuda {
 
-class CudaGuardImpl : public DeviceGuardImpl {
+class CudaGuardImpl final : public DeviceGuardImpl {
 public:
     static void InitSingleStream(Device device);
 
@@ -25,7 +23,7 @@ class CudaGuardImpl : public DeviceGuardImpl {
 
     void SetDevice(Device device) const override;
 
-    int8_t DeviceCount() const override;
+    int DeviceCount() const override;
 
     Device::DeviceType Type() const override;
 
diff --git a/infini_train/src/core/cuda/cuda_stream.cc b/infini_train/src/core/cuda/cuda_stream.cc
index 82d04566..0d97bae9 100644
--- a/infini_train/src/core/cuda/cuda_stream.cc
+++ b/infini_train/src/core/cuda/cuda_stream.cc
@@ -5,8 +5,13 @@
 #include "infini_train/include/common/cuda/common_cuda.h"
 
 namespace infini_train::core::cuda {
+
 CudaStream::CudaStream() { CUDA_CHECK(cudaStreamCreate(&stream_)); }
 
+CudaStream::~CudaStream() {
+    // Do nothing.
+}
+
 cudaStream_t CudaStream::cuda_stream() const { return stream_; }
 
 } // namespace infini_train::core::cuda
diff --git a/infini_train/src/core/cuda/cuda_stream.h b/infini_train/src/core/cuda/cuda_stream.h
index c5252097..b7eb834c 100644
--- a/infini_train/src/core/cuda/cuda_stream.h
+++ b/infini_train/src/core/cuda/cuda_stream.h
@@ -10,6 +10,14 @@ class CudaStream : public Stream {
 public:
     CudaStream();
 
+    // NOTE(dcj):
+    // CudaStream objects are "leaked": they are created but never destroyed, because the
+    // destruction of global variables could happen after the CUDA runtime has
+    // already been destroyed, and invoking cudaStreamDestroy at that point could lead to a
+    // crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
+    // the destruction.
+    ~CudaStream() override;
+
     cudaStream_t cuda_stream() const;
 
 private:
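
Taken together with the RAII wrapper declared in core/device_guard.h, the call-site pattern these backends serve is roughly the following; the exact DeviceGuard constructor signature is assumed here, not shown in this patch.

    // Sketch only -- assumes DeviceGuard(Device) saves, switches, and later restores the current device.
    {
        infini_train::core::DeviceGuard guard(infini_train::Device(infini_train::Device::DeviceType::kCUDA, 1));
        // Work in this scope (allocations, memcpys, kernel launches on the cached
        // stream) sees CUDA device 1 as the current device.
    }   // previous device restored here by ~DeviceGuard
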
" - "The call is ignored.", - static_cast(device.type()), device.index()); -} +// DeviceGuardImpl (base fallback: FATAL only) +void DeviceGuardImpl::SetDevice(Device) const { LOG(FATAL) << "DeviceGuardImpl::SetDevice is not implemented."; } -int8_t DeviceGuardImpl::DeviceCount() const { return -1; } +int DeviceGuardImpl::DeviceCount() const { + LOG(FATAL) << "DeviceGuardImpl::DeviceCount is not implemented."; + return -1; // unreachable +} -Stream *DeviceGuardImpl::GetStream(Device) const { return nullptr; } +Stream *DeviceGuardImpl::GetStream(Device) const { + LOG(FATAL) << "DeviceGuardImpl::GetStream is not implemented."; + return nullptr; // unreachable +} -void DeviceGuardImpl::SynchronizeDevice(Device device) const { - LOG(WARNING) << std::format("SynchronizeDevice is not supported for this device. " - "The call is ignored.", - static_cast(device.type()), device.index()); +void DeviceGuardImpl::SynchronizeDevice(Device) const { + LOG(FATAL) << "DeviceGuardImpl::SynchronizeDevice is not implemented."; } void DeviceGuardImpl::SynchronizeStream(Stream *) const { - LOG(WARNING) << "SynchronizeStream is not supported for this device. " - "The call is ignored."; + LOG(FATAL) << "DeviceGuardImpl::SynchronizeStream is not implemented."; } BlasHandle *DeviceGuardImpl::GetBlasHandle(Device device) const { - LOG(FATAL) << std::format("GetBlasHandle is not supported for device type {} (index {}). ", - static_cast(device.type()), device.index()); + LOG(FATAL) << "DeviceGuardImpl::GetBlasHandle is not implemented."; + return nullptr; // unreachable } void DeviceGuardImpl::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { - LOG(WARNING) << "MallocAsync is not supported on this device. Falling back to blocking Malloc()"; - Malloc(dev_ptr, size); + LOG(FATAL) << "DeviceGuardImpl::MallocAsync is not implemented."; } void DeviceGuardImpl::FreeAsync(void *dev_ptr, Stream *stream) { - LOG(WARNING) << "FreeAsync is not supported on this device. Falling back to blocking Free()"; - Free(dev_ptr); + LOG(FATAL) << "DeviceGuardImpl::FreeAsync is not implemented"; } void DeviceGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { - LOG(WARNING) << "MemcpyAsync is not supported on this device. Falling back to blocking Memcpy()"; - Memcpy(dst, src, count, kind); + LOG(FATAL) << "DeviceGuardImpl::MemcpyAsync is not implemented"; } void DeviceGuardImpl::ResetMemPoolHighWatermarks(Device device) const { - LOG(WARNING) << std::format("ResetMemPoolHighWatermarks is not supported for device type {} (index {}). " - "The call is ignored.", - static_cast(device.type()), device.index()); + LOG(FATAL) << "DeviceGuardImpl::ResetMemPoolHighWatermarks is not implemented."; } std::pair DeviceGuardImpl::GetMemPoolPeakMB(Device device) const { - LOG(WARNING) << std::format("GetMemPoolPeakMB is not supported for device type {} (index {}). 
" - "Returning {{0, 0}}.", - static_cast(device.type()), device.index()); - return {0, 0}; + LOG(FATAL) << "DeviceGuardImpl::GetMemPoolPeakMB is not implemented for device type {} (index {})."; + return {0, 0}; // unreachable } // DeviceGuard diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index f1825f14..d2cbd16a 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -292,12 +292,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string auto &tensor = tensors[i]; // Copy tensor to CPU if it's on GPU - std::shared_ptr cpu_tensor; - if (tensor->GetDevice().IsCUDA()) { - cpu_tensor = std::make_shared(tensor->To(Device())); - } else { - cpu_tensor = tensor; - } + std::shared_ptr cpu_tensor = std::make_shared(tensor->To(Device())); const float *float_data = static_cast(cpu_tensor->DataPtr()); const size_t byte_size = cpu_tensor->SizeInBytes();