Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/nexus-api/nxs.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ enum _nxs_buffer_settings {
};
typedef enum _nxs_buffer_settings nxs_buffer_settings;

/* ENUM _nxs_buffer_transfer */
/*
* NXS_BufferDeviceToHost:
* - Copy buffer from device to host
* NXS_BufferHostToDevice:
* - Copy buffer from host to device
*/
enum _nxs_buffer_transfer {
NXS_BufferDeviceToHost = 0,
NXS_BufferHostToDevice = 1,
};

/********************************************************************************************************/
/* Constants */
Expand Down
2 changes: 1 addition & 1 deletion include/nexus/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Buffer : public Object<detail::BufferImpl> {

Buffer getLocal() const;

nxs_status copy(void *_hostBuf);
nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost);
};

typedef Objects<Buffer> Buffers;
Expand Down
9 changes: 6 additions & 3 deletions plugins/cuda/cuda_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,12 @@ extern "C" nxs_status NXS_API_CALL nxsCopyBuffer(nxs_int buffer_id,
auto buffer = rt->get<rt::Buffer>(buffer_id);
if (!buffer) return NXS_InvalidBuffer;
if (!host_ptr) return NXS_InvalidHostPtr;

CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, host_ptr, buffer->get(),
buffer->size(), cudaMemcpyDeviceToHost);
if (copy_settings == NXS_BufferDeviceToHost)
CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, host_ptr, buffer->get(),
buffer->size(), cudaMemcpyDeviceToHost);
else
CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, buffer->get(), host_ptr,
buffer->size(), cudaMemcpyHostToDevice);
return NXS_Success;
}

Expand Down
2 changes: 1 addition & 1 deletion src/_buffer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BufferImpl : public Impl {
void setData(void *_data) { data = _data; }

Buffer getLocal();
nxs_status copyData(void *_hostBuf) const;
nxs_status copyData(void *_hostBuf, nxs_uint direction) const;

std::string print() const;

Expand Down
6 changes: 3 additions & 3 deletions src/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ Buffer detail::BufferImpl::getLocal() {
return Buffer();
}

nxs_status detail::BufferImpl::copyData(void *_hostBuf) const {
nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) const {
if (nxs_valid_id(getDeviceId())) {
NEXUS_LOG(NXS_LOG_NOTE, "copyData: from device: ", getSize());
auto *rt = getParentOfType<RuntimeImpl>();
return (nxs_status)rt->runAPIFunction<NF_nxsCopyBuffer>(getId(), _hostBuf,
0);
direction);
}
NEXUS_LOG(NXS_LOG_NOTE, "copyData: from host: ", getSize());
memcpy(_hostBuf, getData(), getSize());
Expand Down Expand Up @@ -176,4 +176,4 @@ Buffer Buffer::getLocal() const {
return get()->getLocal();
}

nxs_status Buffer::copy(void *_hostBuf) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf); }
// Copy between this buffer and the caller-supplied host memory.
//   _hostBuf:  host-side pointer handed through to BufferImpl::copyData.
//   direction: transfer direction selector (NXS_BufferDeviceToHost or
//              NXS_BufferHostToDevice), forwarded unchanged.
// NOTE(review): NEXUS_OBJ_MCALL is defined elsewhere; presumably it yields
// NXS_InvalidBuffer when the wrapped object is not valid and otherwise
// returns the result of copyData -- confirm against the macro definition.
nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) {
  NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction);
}
2 changes: 1 addition & 1 deletion test/cpp/test_basic_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ int test_basic_kernel(int argc, char** argv) {
auto time_ms = sched.getProp<nxs_double>(NP_ElapsedTime);
std::cout << "Elapsed time: " << time_ms << std::endl;

buf2.copy(vecResult_GPU.data());
buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost);

int i = 0;
for (auto v : vecResult_GPU) {
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ int test_graph(int argc, char** argv) {
time = sched.getProp<nxs_double>(NP_ElapsedTime);
}

buf2.copy(vecResult_GPU.data());
buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost);

std::cout << std::endl << "Test Time: " << time << std::endl;

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_kernel_catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int test_kernel_catalog(int argc, char** argv) {

sched.run(stream0);

buf2.copy(vecResult_GPU.data());
buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost);

int i = 0;
for (auto v : vecResult_GPU) {
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_multi_stream_sync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ int test_multi_stream_sync(int argc, char** argv) {

evFinal.wait();

buf3.copy(vecResult_GPU.data());
buf3.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost);

int i = 0;
for (auto v : vecResult_GPU) {
Expand Down
Loading