diff --git a/include/nexus-api/nxs.h b/include/nexus-api/nxs.h index 9196feb..8056d39 100644 --- a/include/nexus-api/nxs.h +++ b/include/nexus-api/nxs.h @@ -252,6 +252,17 @@ enum _nxs_buffer_settings { }; typedef enum _nxs_buffer_settings nxs_buffer_settings; +/* ENUM _nxs_buffer_transfer */ +/* + * NXS_BufferDeviceToHost: + * - Copy buffer from device to host + * NXS_BufferHostToDevice: + * - Copy buffer from host to device + */ +enum _nxs_buffer_transfer { + NXS_BufferDeviceToHost = 0, + NXS_BufferHostToDevice = 1, +}; /********************************************************************************************************/ /* Constants */ diff --git a/include/nexus/buffer.h b/include/nexus/buffer.h index fa0d791..a6139a4 100644 --- a/include/nexus/buffer.h +++ b/include/nexus/buffer.h @@ -33,7 +33,7 @@ class Buffer : public Object { Buffer getLocal() const; - nxs_status copy(void *_hostBuf); + nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost); }; typedef Objects Buffers; diff --git a/plugins/cuda/cuda_runtime.cpp b/plugins/cuda/cuda_runtime.cpp index e1ce6a4..070c43d 100644 --- a/plugins/cuda/cuda_runtime.cpp +++ b/plugins/cuda/cuda_runtime.cpp @@ -236,9 +236,12 @@ extern "C" nxs_status NXS_API_CALL nxsCopyBuffer(nxs_int buffer_id, auto buffer = rt->get(buffer_id); if (!buffer) return NXS_InvalidBuffer; if (!host_ptr) return NXS_InvalidHostPtr; - - CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, host_ptr, buffer->get(), - buffer->size(), cudaMemcpyDeviceToHost); + if (copy_settings == NXS_BufferDeviceToHost) + CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, host_ptr, buffer->get(), + buffer->size(), cudaMemcpyDeviceToHost); + else + CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, buffer->get(), host_ptr, + buffer->size(), cudaMemcpyHostToDevice); return NXS_Success; } diff --git a/src/_buffer_impl.h b/src/_buffer_impl.h index 2b72768..640a1f8 100644 --- a/src/_buffer_impl.h +++ b/src/_buffer_impl.h @@ -28,7 +28,7 @@ class BufferImpl : public Impl { void setData(void *_data) { data = _data; } Buffer getLocal(); - nxs_status copyData(void *_hostBuf) const; + nxs_status copyData(void *_hostBuf, nxs_uint direction) const; std::string print() const; diff --git a/src/buffer.cpp b/src/buffer.cpp index d03a04f..29427c1 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -140,12 +140,12 @@ Buffer detail::BufferImpl::getLocal() { return Buffer(); } -nxs_status detail::BufferImpl::copyData(void *_hostBuf) const { +nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) const { if (nxs_valid_id(getDeviceId())) { NEXUS_LOG(NXS_LOG_NOTE, "copyData: from device: ", getSize()); auto *rt = getParentOfType(); return (nxs_status)rt->runAPIFunction(getId(), _hostBuf, - 0); + direction); } NEXUS_LOG(NXS_LOG_NOTE, "copyData: from host: ", getSize()); memcpy(_hostBuf, getData(), getSize()); @@ -176,4 +176,4 @@ Buffer Buffer::getLocal() const { return get()->getLocal(); } -nxs_status Buffer::copy(void *_hostBuf) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf); } +nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction); } diff --git a/test/cpp/test_basic_kernel.cpp b/test/cpp/test_basic_kernel.cpp index bc1bf18..65fcf02 100644 --- a/test/cpp/test_basic_kernel.cpp +++ b/test/cpp/test_basic_kernel.cpp @@ -84,7 +84,7 @@ int test_basic_kernel(int argc, char** argv) { auto time_ms = sched.getProp(NP_ElapsedTime); std::cout << "Elapsed time: " << time_ms << std::endl; - buf2.copy(vecResult_GPU.data()); + buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); int i = 0; for (auto v : vecResult_GPU) { diff --git a/test/cpp/test_graph.cpp b/test/cpp/test_graph.cpp index 3617bd6..7d1f749 100644 --- a/test/cpp/test_graph.cpp +++ b/test/cpp/test_graph.cpp @@ -87,7 +87,7 @@ int test_graph(int argc, char** argv) { time = sched.getProp(NP_ElapsedTime); } - buf2.copy(vecResult_GPU.data()); + buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); std::cout << std::endl << "Test Time: " << time << std::endl; diff --git a/test/cpp/test_kernel_catalog.cpp b/test/cpp/test_kernel_catalog.cpp index 2f10262..21b1d58 100644 --- a/test/cpp/test_kernel_catalog.cpp +++ b/test/cpp/test_kernel_catalog.cpp @@ -82,7 +82,7 @@ int test_kernel_catalog(int argc, char** argv) { sched.run(stream0); - buf2.copy(vecResult_GPU.data()); + buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); int i = 0; for (auto v : vecResult_GPU) { diff --git a/test/cpp/test_multi_stream_sync.cpp b/test/cpp/test_multi_stream_sync.cpp index f91f742..c75c4ac 100644 --- a/test/cpp/test_multi_stream_sync.cpp +++ b/test/cpp/test_multi_stream_sync.cpp @@ -114,7 +114,7 @@ int test_multi_stream_sync(int argc, char** argv) { evFinal.wait(); - buf3.copy(vecResult_GPU.data()); + buf3.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); int i = 0; for (auto v : vecResult_GPU) {