diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a8897a27e2..a171cb6518 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,6 +1,6 @@ # ============================================================================= # cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # cmake-format: on # ============================================================================= @@ -163,8 +163,16 @@ set(SOURCES ) if(KvikIO_REMOTE_SUPPORT) - list(APPEND SOURCES "src/hdfs.cpp" "src/remote_handle.cpp" "src/detail/remote_handle.cpp" - "src/detail/tls.cpp" "src/detail/url.cpp" "src/shim/libcurl.cpp" + list( + APPEND + SOURCES + "src/hdfs.cpp" + "src/remote_handle.cpp" + "src/detail/remote_handle.cpp" + "src/detail/remote_handle_poll_based.cpp" + "src/detail/tls.cpp" + "src/detail/url.cpp" + "src/shim/libcurl.cpp" ) endif() diff --git a/cpp/include/kvikio/defaults.hpp b/cpp/include/kvikio/defaults.hpp index 190909c2cc..43b3d62e55 100644 --- a/cpp/include/kvikio/defaults.hpp +++ b/cpp/include/kvikio/defaults.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,9 @@ bool getenv_or(std::string_view env_var_name, bool default_val); template <> CompatMode getenv_or(std::string_view env_var_name, CompatMode default_val); +template <> +RemoteBackendType getenv_or(std::string_view env_var_name, RemoteBackendType default_val); + template <> std::vector getenv_or(std::string_view env_var_name, std::vector default_val); @@ -122,6 +126,9 @@ class defaults { bool _auto_direct_io_read; bool _auto_direct_io_write; bool _thread_pool_per_block_device; + RemoteBackendType _remote_backend; + std::size_t _remote_max_connections; + std::size_t _num_bounce_buffers; static unsigned int get_num_threads_from_env(); @@ -417,6 +424,58 @@ class defaults { * thread pool for all I/O operations. */ static void set_thread_pool_per_block_device(bool flag); + + /** + * @brief Get the current remote I/O backend type. + * + * @return The currently configured RemoteBackendType. + */ + [[nodiscard]] static RemoteBackendType remote_backend(); + + /** + * @brief Set the remote I/O backend type. + * + * Note: Changing this after creating a RemoteHandle has no effect on existing handles. The + * backend is determined at RemoteHandle construction time. + * + * @param remote_backend The backend type to use for new RemoteHandle instances. + */ + static void set_remote_backend(RemoteBackendType remote_backend); + + /** + * @brief Get the maximum number of concurrent connections for poll-based remote I/O. + * + * Only applies when using RemoteBackendType::LIBCURL_MULTI_POLL. + * + * @return Maximum number of concurrent connections. + */ + [[nodiscard]] static std::size_t remote_max_connections(); + + /** + * @brief Set the maximum number of concurrent connections for poll-based remote I/O. + * + * Only applies when using RemoteBackendType::LIBCURL_MULTI_POLL. + * + * @param remote_max_connections Maximum concurrent connections (must be positive). + */ + static void set_remote_max_connections(std::size_t remote_max_connections); + + /** + * @brief Get the number of bounce buffers used per connection for poll-based remote I/O. + * + * Controls k-way buffering: higher values allow more overlap between network I/O and H2D + * transfers but consume more pinned memory. + * + * @return Number of bounce buffers per connection. + */ + [[nodiscard]] static std::size_t num_bounce_buffers(); + + /** + * @brief Set the number of bounce buffers used per connection for poll-based remote I/O. + * + * @param num_bounce_buffers Number of bounce buffers per connection (must be positive). + */ + static void set_num_bounce_buffers(std::size_t num_bounce_buffers); }; } // namespace kvikio diff --git a/cpp/include/kvikio/detail/remote_handle.hpp b/cpp/include/kvikio/detail/remote_handle.hpp index 2e6613aeef..fc4da9b575 100644 --- a/cpp/include/kvikio/detail/remote_handle.hpp +++ b/cpp/include/kvikio/detail/remote_handle.hpp @@ -1,12 +1,52 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include + namespace kvikio::detail { +/** + * @brief Check a libcurl easy interface return code and throw on error. + * + * @param err_code The CURLcode to check. + * @exception std::runtime_error if err_code is not CURLE_OK. + */ +#define KVIKIO_CHECK_CURL_EASY(err_code) \ + kvikio::detail::check_curl_easy(err_code, __FILE__, __LINE__) + +/** + * @brief Check a libcurl multi interface return code and throw on error. + * + * @param err_code The CURLMcode to check. + * @exception std::runtime_error if err_code is not CURLM_OK. + */ +#define KVIKIO_CHECK_CURL_MULTI(err_code) \ + kvikio::detail::check_curl_multi(err_code, __FILE__, __LINE__) + +/** + * @brief Check a libcurl easy interface return code and throw on error. + * + * @param err_code The CURLcode to check. + * @param filename Source filename for error reporting. + * @param line_number Source line number for error reporting. + * @exception std::runtime_error if err_code is not CURLE_OK. + */ +void check_curl_easy(CURLcode err_code, char const* filename, int line_number); + +/** + * @brief Check a libcurl multi interface return code and throw on error. + * + * @param err_code The CURLMcode to check. + * @param filename Source filename for error reporting. + * @param line_number Source line number for error reporting. + * @exception std::runtime_error if err_code is not CURLM_OK. + */ +void check_curl_multi(CURLMcode err_code, char const* filename, int line_number); + /** * @brief Callback for `CURLOPT_WRITEFUNCTION` that copies received data into a `std::string`. * @@ -20,4 +60,13 @@ std::size_t callback_get_string_response(char* data, std::size_t size, std::size_t num_bytes, void* userdata); + +/** + * @brief Set up the range request for libcurl. Use this method when HTTP range request is supposed. + * + * @param curl A curl handle + * @param file_offset File offset + * @param size read size + */ +void setup_range_request_impl(CurlHandle& curl, std::size_t file_offset, std::size_t size); } // namespace kvikio::detail diff --git a/cpp/include/kvikio/detail/remote_handle_poll_based.hpp b/cpp/include/kvikio/detail/remote_handle_poll_based.hpp new file mode 100644 index 0000000000..e26677dbaa --- /dev/null +++ b/cpp/include/kvikio/detail/remote_handle_poll_based.hpp @@ -0,0 +1,133 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +#include +#include +#include + +namespace kvikio::detail { + +/** + * @brief Manages a rotating set of bounce buffers for overlapping network I/O with H2D transfers. + * + * This class implements k-way buffering, rotating through buffers circularly: while one buffer + * receives data from the network, previously filled buffers can be asynchronously copied to device + * memory. When all buffers have been used, the class synchronizes the CUDA stream before reusing + * buffers. + */ +class BounceBufferManager { + public: + /** + * @brief Construct a BounceBufferManager with the specified number of bounce buffers. + * + * @param num_bounce_buffers Number of bounce buffers to allocate from the pool. + */ + BounceBufferManager(std::size_t num_bounce_buffers); + + /** + * @brief Get a pointer to the current bounce buffer's data. + * + * @return Pointer to the current buffer's memory. + */ + void* data() const noexcept; + + /** + * @brief Copy data from the current bounce buffer to device memory and rotate to the next buffer. + * + * Issues an asynchronous H2D copy and advances to the next buffer. When wrapping around to buffer + * 0, synchronizes the stream to ensure all previous copies have completed before reuse. + * + * @param dst Device memory destination pointer. + * @param size Number of bytes to copy. + * @param stream CUDA stream for the asynchronous copy. + * @exception kvikio::CUfileException if size exceeds bounce buffer capacity. + */ + void copy(void* dst, std::size_t size, CUstream stream); + + private: + std::size_t _bounce_buffer_idx{}; + std::size_t _num_bounce_buffers{}; + std::vector _bounce_buffers; +}; + +/** + * @brief Context for tracking the state of a single chunked transfer. + * + * Each concurrent connection has an associated TransferContext that tracks the destination buffer, + * transfer progress, and manages optional bounce buffers for GPU destinations. + */ +struct TransferContext { + bool overflow_error{}; + bool is_host_mem{}; + char* buf{}; + CurlHandle* curl_easy_handle{}; + std::size_t chunk_size{}; + std::size_t bytes_transferred{}; + std::optional _bounce_buffer_manager; +}; + +/** + * @brief Poll-based remote file handle using libcurl's multi interface. + * + * This class provides an alternative to the thread-pool-based remote I/O by using libcurl's multi + * interface with curl_multi_poll() for managing concurrent connections. It implements chunked + * parallel downloads with k-way buffering to overlap network transfers with host-to-device memory + * copies. + * + * @note Thread safety: The pread() method is protected by a mutex, making it safe to call from + * multiple threads, though calls will be serialized. + */ +class RemoteHandlePollBased { + private: + CURLM* _multi; + std::size_t _max_connections; + std::vector> _curl_easy_handles; + std::vector _transfer_ctxs; + RemoteEndpoint* _endpoint; + mutable std::mutex _mutex; + + public: + /** + * @brief Construct a poll-based remote handle. + * + * Initializes the libcurl multi handle and creates the specified number of easy handles for + * concurrent transfers. + * + * @param endpoint Non-owning pointer to the remote endpoint. Must outlive this object. + * @param max_connections Maximum number of concurrent connections to use. + * @exception kvikio::CUfileException if task_size exceeds bounce_buffer_size. + * @exception kvikio::CUfileException if libcurl multi initialization fails. + */ + RemoteHandlePollBased(RemoteEndpoint* endpoint, std::size_t max_connections); + + /** + * @brief Destructor that cleans up libcurl multi resources. + * + * Removes all easy handles from the multi handle and performs cleanup. Errors during cleanup are + * logged but do not throw. + */ + ~RemoteHandlePollBased() noexcept; + + /** + * @brief Read data from the remote file into a buffer. + * + * Performs a parallel chunked read using multiple concurrent HTTP range requests. For device + * memory destinations, uses bounce buffers with k-way buffering to overlap network I/O with H2D + * transfers. + * + * @param buf Destination buffer (host or device memory). + * @param size Number of bytes to read. + * @param file_offset Offset in the remote file to start reading from. + * @return Number of bytes actually read. + * @exception std::overflow_error if the server returns more data than expected (may indicate the + * server doesn't support range requests). + * @exception std::runtime_error on libcurl errors. + */ + std::size_t pread(void* buf, std::size_t size, std::size_t file_offset = 0); +}; +} // namespace kvikio::detail diff --git a/cpp/include/kvikio/remote_backend_type.hpp b/cpp/include/kvikio/remote_backend_type.hpp new file mode 100644 index 0000000000..95bba088e2 --- /dev/null +++ b/cpp/include/kvikio/remote_backend_type.hpp @@ -0,0 +1,27 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +namespace kvikio { + +/** + * @brief Enum representing the backend implementation for remote file I/O operations. + * + * KvikIO supports multiple libcurl-based backends for fetching data from remote endpoints (S3, + * HTTP, etc.). Each backend has different performance characteristics. + */ +enum class RemoteBackendType : uint8_t { + LIBCURL_EASY, ///< Use libcurl's easy interface with a thread pool for parallelism. Each chunk is + ///< fetched by a separate thread using blocking curl_easy_perform() calls. This is + ///< the default backend. + LIBCURL_MULTI_POLL, ///< Use libcurl's multi interface with poll-based concurrent transfers. A + ///< single call manages multiple concurrent connections using + ///< curl_multi_poll(), with k-way buffering to overlap network I/O with + ///< host-to-device transfers. +}; + +} // namespace kvikio diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 0b0808c45e..5c4a78029e 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -291,6 +292,11 @@ class S3EndpointWithPresignedUrl : public RemoteEndpoint { static bool is_url_valid(std::string const& url) noexcept; }; +// Forward declaration +namespace detail { +class RemoteHandlePollBased; +} + /** * @brief Handle of remote file. */ @@ -298,6 +304,8 @@ class RemoteHandle { private: std::unique_ptr _endpoint; std::size_t _nbytes; + std::unique_ptr _poll_handle; + RemoteBackendType _remote_backend_type; public: /** @@ -400,8 +408,9 @@ class RemoteHandle { RemoteHandle(std::unique_ptr endpoint); // A remote handle is moveable but not copyable. - RemoteHandle(RemoteHandle&& o) = default; - RemoteHandle& operator=(RemoteHandle&& o) = default; + ~RemoteHandle(); + RemoteHandle(RemoteHandle&& o); + RemoteHandle& operator=(RemoteHandle&& o); RemoteHandle(RemoteHandle const&) = delete; RemoteHandle& operator=(RemoteHandle const&) = delete; diff --git a/cpp/src/defaults.cpp b/cpp/src/defaults.cpp index 841e7314d3..d8ff218957 100644 --- a/cpp/src/defaults.cpp +++ b/cpp/src/defaults.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -16,8 +17,8 @@ #include #include #include +#include #include -#include namespace kvikio { template <> @@ -63,6 +64,26 @@ CompatMode getenv_or(std::string_view env_var_name, CompatMode default_val) return detail::parse_compat_mode_str(env_val); } +template <> +RemoteBackendType getenv_or(std::string_view env_var_name, RemoteBackendType default_val) +{ + KVIKIO_NVTX_FUNC_RANGE(); + auto* env_val = std::getenv(env_var_name.data()); + if (env_val == nullptr) { return default_val; } + std::string str{env_val}; + std::transform( + str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); }); + std::stringstream trimmer; + trimmer << str; + str.clear(); + trimmer >> str; + if (str == "libcurl_easy") { return RemoteBackendType::LIBCURL_EASY; } + if (str == "libcurl_multi_poll") { return RemoteBackendType::LIBCURL_MULTI_POLL; } + KVIKIO_FAIL("unknown config value " + std::string{env_var_name} + "=" + std::string{env_val}, + std::invalid_argument); + return {}; +} + template <> std::vector getenv_or(std::string_view env_var_name, std::vector default_val) { @@ -131,20 +152,30 @@ defaults::defaults() } // Determine the default value of `http_status_codes` - { - _http_status_codes = - getenv_or("KVIKIO_HTTP_STATUS_CODES", std::vector{429, 500, 502, 503, 504}); - } + _http_status_codes = + getenv_or("KVIKIO_HTTP_STATUS_CODES", std::vector{429, 500, 502, 503, 504}); // Determine the default value of `auto_direct_io_read` and `auto_direct_io_write` + _auto_direct_io_read = getenv_or("KVIKIO_AUTO_DIRECT_IO_READ", false); + _auto_direct_io_write = getenv_or("KVIKIO_AUTO_DIRECT_IO_WRITE", true); + + // Determine the default value of `thread_pool_per_block_device` + _thread_pool_per_block_device = getenv_or("KVIKIO_THREAD_POOL_PER_BLOCK_DEVICE", false); + + _remote_backend = getenv_or("KVIKIO_REMOTE_BACKEND", RemoteBackendType::LIBCURL_EASY); + { - _auto_direct_io_read = getenv_or("KVIKIO_AUTO_DIRECT_IO_READ", false); - _auto_direct_io_write = getenv_or("KVIKIO_AUTO_DIRECT_IO_WRITE", true); + auto const env = getenv_or("KVIKIO_REMOTE_MAX_CONNECTIONS", 8); + KVIKIO_EXPECT( + env > 0, "KVIKIO_REMOTE_MAX_CONNECTIONS has to be a positive integer", std::invalid_argument); + _remote_max_connections = env; } - // Determine the default value of `thread_pool_per_block_device` { - _thread_pool_per_block_device = getenv_or("KVIKIO_THREAD_POOL_PER_BLOCK_DEVICE", false); + auto const env = getenv_or("KVIKIO_NUM_BOUNCE_BUFFERS", 2); + KVIKIO_EXPECT( + env > 0, "KVIKIO_NUM_BOUNCE_BUFFERS has to be a positive integer", std::invalid_argument); + _num_bounce_buffers = env; } } @@ -250,4 +281,25 @@ void defaults::set_thread_pool_per_block_device(bool flag) { instance()->_thread_pool_per_block_device = flag; } + +RemoteBackendType defaults::remote_backend() { return instance()->_remote_backend; } + +void defaults::set_remote_backend(RemoteBackendType remote_backend) +{ + instance()->_remote_backend = remote_backend; +} + +std::size_t defaults::remote_max_connections() { return instance()->_remote_max_connections; } + +void defaults::set_remote_max_connections(std::size_t remote_max_connections) +{ + instance()->_remote_max_connections = remote_max_connections; +} + +std::size_t defaults::num_bounce_buffers() { return instance()->_num_bounce_buffers; } + +void defaults::set_num_bounce_buffers(std::size_t num_bounce_buffers) +{ + instance()->_num_bounce_buffers = num_bounce_buffers; +} } // namespace kvikio diff --git a/cpp/src/detail/remote_handle.cpp b/cpp/src/detail/remote_handle.cpp index 87d1ed5ab5..aba672458b 100644 --- a/cpp/src/detail/remote_handle.cpp +++ b/cpp/src/detail/remote_handle.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,24 @@ #include namespace kvikio::detail { +void check_curl_easy(CURLcode err_code, char const* filename, int line_number) +{ + if (err_code == CURLcode::CURLE_OK) { return; } + std::stringstream ss; + ss << "libcurl error: " << curl_easy_strerror(err_code) << " at: " << filename << ":" + << line_number << "\n"; + throw std::runtime_error(ss.str()); +} + +void check_curl_multi(CURLMcode err_code, char const* filename, int line_number) +{ + if (err_code == CURLMcode::CURLM_OK) { return; } + std::stringstream ss; + ss << "libcurl error: " << curl_multi_strerror(err_code) << " at: " << filename << ":" + << line_number << "\n"; + throw std::runtime_error(ss.str()); +} + std::size_t callback_get_string_response(char* data, std::size_t size, std::size_t num_bytes, @@ -18,4 +36,11 @@ std::size_t callback_get_string_response(char* data, response->append(data, new_data_size); return new_data_size; } + +void setup_range_request_impl(CurlHandle& curl, std::size_t file_offset, std::size_t size) +{ + std::string const byte_range = + std::to_string(file_offset) + "-" + std::to_string(file_offset + size - 1); + curl.setopt(CURLOPT_RANGE, byte_range.c_str()); +} } // namespace kvikio::detail diff --git a/cpp/src/detail/remote_handle_poll_based.cpp b/cpp/src/detail/remote_handle_poll_based.cpp new file mode 100644 index 0000000000..4dce3f5301 --- /dev/null +++ b/cpp/src/detail/remote_handle_poll_based.cpp @@ -0,0 +1,251 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace kvikio::detail { +namespace { +/** + * @brief Callback function for libcurl's CURLOPT_WRITEFUNCTION. + * + * Called by libcurl when data is received from the remote server. Copies the received data either + * directly to host memory or to a bounce buffer (for device memory destinations). + * + * @param buffer Pointer to the received data. + * @param size Size of each data element (always 1). + * @param nmemb Number of data elements received. + * @param userdata Pointer to the TransferContext for this transfer. + * @return Number of bytes processed, or CURL_WRITEFUNC_ERROR if the received data would overflow + * the expected chunk size. + */ +std::size_t write_callback(char* buffer, std::size_t size, std::size_t nmemb, void* userdata) +{ + auto* ctx = reinterpret_cast(userdata); + std::size_t const nbytes = size * nmemb; + + if (ctx->chunk_size < ctx->bytes_transferred + nbytes) { + ctx->overflow_error = true; + return CURL_WRITEFUNC_ERROR; + } + KVIKIO_NVTX_FUNC_RANGE(nbytes); + void* dst = ctx->is_host_mem ? ctx->buf : ctx->_bounce_buffer_manager->data(); + std::memcpy(static_cast(dst) + ctx->bytes_transferred, buffer, nbytes); + ctx->bytes_transferred += nbytes; + return nbytes; +} + +/** + * @brief Reconfigure a libcurl easy handle for a new chunk transfer. + * + * Resets the transfer context and configures the easy handle to fetch the next chunk using an HTTP + * range request. For device memory destinations, lazily initializes the bounce buffer manager on + * first use. + * + * @param curl_easy_handle The libcurl easy handle to reconfigure. + * @param endpoint Non-owning pointer to the remote endpoint. Must outlive this object. + * @param ctx Transfer context to reset and associate with this chunk. + * @param buf Base destination buffer pointer (host or device memory). + * @param is_host_mem True if buf points to host memory, false for device memory. + * @param current_chunk_idx Zero-based index of the chunk to fetch. + * @param chunk_size Size of each chunk (from defaults::task_size()). + * @param size Total size of the read operation. + * @param file_offset Starting offset in the remote file for the overall read. + * @exception std::runtime_error if setting the CURLOPT_RANGE option fails. + */ +void reconfig_easy_handle(CurlHandle& curl_easy_handle, + RemoteEndpoint* endpoint, + TransferContext* ctx, + void* buf, + bool is_host_mem, + std::size_t current_chunk_idx, + std::size_t chunk_size, + std::size_t size, + std::size_t file_offset) +{ + auto const local_offset = current_chunk_idx * chunk_size; + auto const actual_chunk_size = std::min(chunk_size, size - local_offset); + + ctx->overflow_error = false; + ctx->is_host_mem = is_host_mem; + ctx->buf = static_cast(buf) + local_offset; + ctx->curl_easy_handle = &curl_easy_handle; + ctx->chunk_size = actual_chunk_size; + ctx->bytes_transferred = 0; + + if (!is_host_mem && !ctx->_bounce_buffer_manager.has_value()) { + ctx->_bounce_buffer_manager.emplace(defaults::num_bounce_buffers()); + } + + endpoint->setup_range_request(curl_easy_handle, file_offset + local_offset, actual_chunk_size); +} +} // namespace + +BounceBufferManager::BounceBufferManager(std::size_t num_bounce_buffers) + : _num_bounce_buffers{num_bounce_buffers} +{ + for (std::size_t i = 0; i < _num_bounce_buffers; ++i) { + _bounce_buffers.emplace_back(CudaPinnedBounceBufferPool::instance().get()); + } +} + +void* BounceBufferManager::data() const noexcept +{ + return _bounce_buffers[_bounce_buffer_idx].get(); +} + +void BounceBufferManager::copy(void* dst, std::size_t size, CUstream stream) +{ + KVIKIO_EXPECT(size <= defaults::bounce_buffer_size(), + "Host-to-device copy size exceeds bounce buffer capacity"); + CUDA_DRIVER_TRY( + cudaAPI::instance().MemcpyHtoDAsync(convert_void2deviceptr(dst), data(), size, stream)); + ++_bounce_buffer_idx; + if (_bounce_buffer_idx == _bounce_buffers.size()) { + _bounce_buffer_idx = 0; + CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(stream)); + } +} + +RemoteHandlePollBased::RemoteHandlePollBased(RemoteEndpoint* endpoint, std::size_t max_connections) + : _max_connections{max_connections}, _transfer_ctxs(_max_connections), _endpoint{endpoint} +{ + KVIKIO_EXPECT(defaults::task_size() <= defaults::bounce_buffer_size(), + "bounce buffer size cannot be less than task size."); + + _multi = curl_multi_init(); + KVIKIO_EXPECT(_multi != nullptr, "Failed to initialize libcurl multi API"); + + _curl_easy_handles.reserve(_max_connections); + for (std::size_t i = 0; i < _max_connections; ++i) { + _curl_easy_handles.emplace_back( + std::make_unique(kvikio::LibCurl::instance().get_handle(), + kvikio::detail::fix_conda_file_path_hack(__FILE__), + KVIKIO_STRINGIFY(__LINE__))); + + // Initialize easy handle, associate it with transfer context + _endpoint->setopt(*_curl_easy_handles.back()); + _curl_easy_handles.back()->setopt(CURLOPT_WRITEFUNCTION, write_callback); + _curl_easy_handles.back()->setopt(CURLOPT_WRITEDATA, &_transfer_ctxs[i]); + _curl_easy_handles.back()->setopt(CURLOPT_PRIVATE, &_transfer_ctxs[i]); + } +} + +RemoteHandlePollBased::~RemoteHandlePollBased() noexcept +{ + try { + // Remove any lingering handles before cleanup + for (auto& handle : _curl_easy_handles) { + // Ignore errors + KVIKIO_CHECK_CURL_MULTI(curl_multi_remove_handle(_multi, handle->handle())); + } + KVIKIO_CHECK_CURL_MULTI(curl_multi_cleanup(_multi)); + } catch (std::exception const& e) { + KVIKIO_LOG_ERROR(e.what()); + } +} + +std::size_t RemoteHandlePollBased::pread(void* buf, std::size_t size, std::size_t file_offset) +{ + if (size == 0) return 0; + + std::lock_guard lock{_mutex}; + + // Prepare for the run + bool const is_host_mem = is_host_memory(buf); + auto const chunk_size = defaults::task_size(); + auto const num_chunks = (size + chunk_size - 1) / chunk_size; + auto const actual_max_connections = std::min(_max_connections, num_chunks); + + std::optional cuda_ctx; + CUstream stream{}; + if (!is_host_mem) { + cuda_ctx.emplace(get_context_from_pointer(buf)); + stream = detail::StreamsByThread::get(); + } + + std::size_t num_byte_transferred{0}; + std::size_t current_chunk_idx{0}; + for (std::size_t i = 0; i < actual_max_connections; ++i) { + reconfig_easy_handle(*_curl_easy_handles[i], + _endpoint, + &_transfer_ctxs[i], + buf, + is_host_mem, + current_chunk_idx, + chunk_size, + size, + file_offset); + KVIKIO_CHECK_CURL_MULTI(curl_multi_add_handle(_multi, _curl_easy_handles[i]->handle())); + ++current_chunk_idx; + } + + // Start the run + std::size_t in_flight{actual_max_connections}; + int still_running{0}; + while (in_flight > 0) { + KVIKIO_CHECK_CURL_MULTI(curl_multi_perform(_multi, &still_running)); + + CURLMsg* msg; + int msgs_left; + + // Handle the completed messages + while ((msg = curl_multi_info_read(_multi, &msgs_left))) { + if (msg->msg != CURLMSG_DONE) continue; + + TransferContext* ctx{nullptr}; + KVIKIO_CHECK_CURL_EASY(curl_easy_getinfo(msg->easy_handle, CURLINFO_PRIVATE, &ctx)); + + KVIKIO_EXPECT(ctx != nullptr, "Failed to retrieve transfer context"); + KVIKIO_EXPECT(!ctx->overflow_error, + "Overflow detected. Maybe the server doesn't support file ranges", + std::overflow_error); + KVIKIO_EXPECT(msg->data.result == CURLE_OK, + "Chunked transfer failed in poll-based multi API"); + + if (!is_host_mem) { + ctx->_bounce_buffer_manager->copy(ctx->buf, ctx->bytes_transferred, stream); + } + + num_byte_transferred += ctx->bytes_transferred; + --in_flight; + KVIKIO_CHECK_CURL_MULTI(curl_multi_remove_handle(_multi, msg->easy_handle)); + + if (current_chunk_idx < num_chunks) { + reconfig_easy_handle(*ctx->curl_easy_handle, + _endpoint, + ctx, + buf, + is_host_mem, + current_chunk_idx, + chunk_size, + size, + file_offset); + KVIKIO_CHECK_CURL_MULTI(curl_multi_add_handle(_multi, msg->easy_handle)); + ++current_chunk_idx; + ++in_flight; + } + } + + if (in_flight > 0) { + constexpr int timeout_ms{1000}; + KVIKIO_CHECK_CURL_MULTI(curl_multi_poll(_multi, nullptr, 0, timeout_ms, nullptr)); + } + } + + // Ensure all H2D transfers complete before returning + if (!is_host_mem) { CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(stream)); } + + return num_byte_transferred; +} +} // namespace kvikio::detail diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp index 7c917c9a0b..19882800d6 100644 --- a/cpp/src/remote_handle.cpp +++ b/cpp/src/remote_handle.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -21,9 +22,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -157,20 +160,6 @@ std::size_t get_file_size_using_head_impl(RemoteEndpoint& endpoint, std::string return static_cast(cl); } -/** - * @brief Set up the range request for libcurl. Use this method when HTTP range request is supposed. - * - * @param curl A curl handle - * @param file_offset File offset - * @param size read size - */ -void setup_range_request_impl(CurlHandle& curl, std::size_t file_offset, std::size_t size) -{ - std::string const byte_range = - std::to_string(file_offset) + "-" + std::to_string(file_offset + size - 1); - curl.setopt(CURLOPT_RANGE, byte_range.c_str()); -} - /** * @brief Whether the given URL is compatible with the S3 endpoint (including the credential-based * access and presigned URL) which uses HTTP/HTTPS. @@ -253,7 +242,7 @@ std::size_t HttpEndpoint::get_file_size() void HttpEndpoint::setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) { - setup_range_request_impl(curl, file_offset, size); + detail::setup_range_request_impl(curl, file_offset, size); } bool HttpEndpoint::is_url_valid(std::string const& url) noexcept @@ -409,7 +398,7 @@ std::size_t S3Endpoint::get_file_size() void S3Endpoint::setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) { KVIKIO_NVTX_FUNC_RANGE(); - setup_range_request_impl(curl, file_offset, size); + detail::setup_range_request_impl(curl, file_offset, size); } bool S3Endpoint::is_url_valid(std::string const& url) noexcept @@ -456,7 +445,7 @@ void S3PublicEndpoint::setup_range_request(CurlHandle& curl, std::size_t size) { KVIKIO_NVTX_FUNC_RANGE(); - setup_range_request_impl(curl, file_offset, size); + detail::setup_range_request_impl(curl, file_offset, size); } bool S3PublicEndpoint::is_url_valid(std::string const& url) noexcept @@ -552,7 +541,7 @@ void S3EndpointWithPresignedUrl::setup_range_request(CurlHandle& curl, std::size_t size) { KVIKIO_NVTX_FUNC_RANGE(); - setup_range_request_impl(curl, file_offset, size); + detail::setup_range_request_impl(curl, file_offset, size); } bool S3EndpointWithPresignedUrl::is_url_valid(std::string const& url) noexcept @@ -671,18 +660,36 @@ RemoteHandle RemoteHandle::open(std::string url, } RemoteHandle::RemoteHandle(std::unique_ptr endpoint, std::size_t nbytes) - : _endpoint{std::move(endpoint)}, _nbytes{nbytes} + : _endpoint{std::move(endpoint)}, + _nbytes{nbytes}, + _remote_backend_type{defaults::remote_backend()} { KVIKIO_NVTX_FUNC_RANGE(); + if (_remote_backend_type == RemoteBackendType::LIBCURL_MULTI_POLL) { + _poll_handle = std::make_unique( + _endpoint.get(), defaults::remote_max_connections()); + } } RemoteHandle::RemoteHandle(std::unique_ptr endpoint) + : _remote_backend_type{defaults::remote_backend()} { KVIKIO_NVTX_FUNC_RANGE(); _nbytes = endpoint->get_file_size(); _endpoint = std::move(endpoint); + if (_remote_backend_type == RemoteBackendType::LIBCURL_MULTI_POLL) { + _poll_handle = std::make_unique( + _endpoint.get(), defaults::remote_max_connections()); + } } +// Destructor and move operations must be defined in the .cpp file (not defaulted in the header) +// because RemoteHandle uses std::unique_ptr with a forward-declared type. +// The unique_ptr's deleter requires the complete type definition, which is only available here. +RemoteHandle::~RemoteHandle() = default; +RemoteHandle::RemoteHandle(RemoteHandle&& o) = default; +RemoteHandle& RemoteHandle::operator=(RemoteHandle&& o) = default; + RemoteEndpointType RemoteHandle::remote_endpoint_type() const noexcept { return _endpoint->remote_endpoint_type(); @@ -815,6 +822,14 @@ std::future RemoteHandle::pread(void* buf, KVIKIO_EXPECT(thread_pool != nullptr, "The thread pool must not be nullptr"); auto& [nvtx_color, call_idx] = detail::get_next_color_and_call_idx(); KVIKIO_NVTX_FUNC_RANGE(size); + + if (defaults::remote_backend() == RemoteBackendType::LIBCURL_MULTI_POLL) { + return thread_pool->submit_task([=, this] { + KVIKIO_NVTX_SCOPED_RANGE("task_remote_multi_poll", size, nvtx_color); + return _poll_handle->pread(buf, size, file_offset); + }); + } + auto task = [this](void* devPtr_base, std::size_t size, std::size_t file_offset, diff --git a/docs/source/runtime_settings.rst b/docs/source/runtime_settings.rst index 5bbed7f063..b196064bb3 100644 --- a/docs/source/runtime_settings.rst +++ b/docs/source/runtime_settings.rst @@ -121,3 +121,47 @@ Example: # Enable Direct I/O for reads, and disable it for writes kvikio.defaults.set({"auto_direct_io_read": True, "auto_direct_io_write": False}) + +Remote Backend ``KVIKIO_REMOTE_BACKEND`` +---------------------------------------- + +KvikIO supports multiple libcurl-based backends for fetching data from remote endpoints (S3, HTTP, etc.). Set the environment variable ``KVIKIO_REMOTE_BACKEND`` to one of the following options (case-insensitive): + + * ``LIBCURL_EASY``: Use libcurl's easy interface with a thread pool for parallelism. Each chunk is fetched by a separate thread using blocking ``curl_easy_perform()`` calls. This is the default backend. + * ``LIBCURL_MULTI_POLL``: Use libcurl's multi interface with poll-based concurrent transfers. A single call manages multiple concurrent connections using ``curl_multi_poll()``, with k-way buffering to overlap network I/O with host-to-device transfers. + +If not set, the default value is ``LIBCURL_EASY``. + +.. note:: + + Changing this setting after creating a ``RemoteHandle`` has no effect on existing handles. The backend is determined at ``RemoteHandle`` construction time. + +This setting can be queried (:py:func:`kvikio.defaults.get`) and modified (:py:func:`kvikio.defaults.set`) at runtime using the property name ``remote_backend``. + +Remote Max Connections ``KVIKIO_REMOTE_MAX_CONNECTIONS`` +-------------------------------------------------------- + +When using the ``LIBCURL_MULTI_POLL`` backend, this setting controls the maximum number of concurrent HTTP connections used for parallel chunk downloads. Set the environment variable ``KVIKIO_REMOTE_MAX_CONNECTIONS`` to a positive integer. + +If not set, the default value is 8. + +.. note:: + + This setting only applies when using ``RemoteBackendType.LIBCURL_MULTI_POLL``. It has no effect on the ``LIBCURL_EASY`` backend, which uses the thread pool size (``KVIKIO_NTHREADS``) to control parallelism. + +This setting can be queried (:py:func:`kvikio.defaults.get`) and modified (:py:func:`kvikio.defaults.set`) at runtime using the property name ``remote_max_connections``. + +Number of Bounce Buffers ``KVIKIO_NUM_BOUNCE_BUFFERS`` +------------------------------------------------------ + +When using the ``LIBCURL_MULTI_POLL`` backend with device memory destinations, KvikIO uses k-way buffering to overlap network I/O with host-to-device memory transfers. This setting controls the number of bounce buffers allocated per connection. + +Set the environment variable ``KVIKIO_NUM_BOUNCE_BUFFERS`` to a positive integer. Higher values allow more overlap between network I/O and H2D transfers but consume more pinned memory. The total pinned memory usage is ``remote_max_connections * num_bounce_buffers * bounce_buffer_size``. + +If not set, the default value is 2. + +.. note:: + + This setting only applies when using ``RemoteBackendType.LIBCURL_MULTI_POLL`` with device memory destinations. For host memory destinations or the ``LIBCURL_EASY`` backend, bounce buffers are not used. + +This setting can be queried (:py:func:`kvikio.defaults.get`) and modified (:py:func:`kvikio.defaults.set`) at runtime using the property name ``num_bounce_buffers``. diff --git a/python/kvikio/kvikio/__init__.py b/python/kvikio/kvikio/__init__.py index e00aa40c14..e42e5e583d 100644 --- a/python/kvikio/kvikio/__init__.py +++ b/python/kvikio/kvikio/__init__.py @@ -12,7 +12,7 @@ del libkvikio -from kvikio._lib.defaults import CompatMode # noqa: F401 +from kvikio._lib.defaults import CompatMode, RemoteBackendType # noqa: F401 from kvikio._version import __git_commit__, __version__ from kvikio.buffer import bounce_buffer_free, memory_deregister, memory_register from kvikio.cufile import CuFile, clear_page_cache, get_page_cache_info diff --git a/python/kvikio/kvikio/_lib/defaults.pyx b/python/kvikio/kvikio/_lib/defaults.pyx index e5bfbca713..2752356dbe 100644 --- a/python/kvikio/kvikio/_lib/defaults.pyx +++ b/python/kvikio/kvikio/_lib/defaults.pyx @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # distutils: language = c++ @@ -14,6 +14,9 @@ cdef extern from "" namespace "kvikio" nogil: OFF = 0 ON = 1 AUTO = 2 + cpdef enum class RemoteBackendType(uint8_t): + LIBCURL_EASY = 0 + LIBCURL_MULTI_POLL = 1 bool cpp_is_compat_mode_preferred \ "kvikio::defaults::is_compat_mode_preferred"() except + CompatMode cpp_compat_mode "kvikio::defaults::compat_mode"() except + @@ -42,10 +45,19 @@ cdef extern from "" namespace "kvikio" nogil: "kvikio::defaults::set_http_timeout"(long timeout_seconds) except + bool cpp_auto_direct_io_read "kvikio::defaults::auto_direct_io_read"() except + void cpp_set_auto_direct_io_read \ - "kvikio::defaults::set_auto_direct_io_read"(size_t flag) except + + "kvikio::defaults::set_auto_direct_io_read"(bool flag) except + bool cpp_auto_direct_io_write "kvikio::defaults::auto_direct_io_write"() except + void cpp_set_auto_direct_io_write \ - "kvikio::defaults::set_auto_direct_io_write"(size_t flag) except + + "kvikio::defaults::set_auto_direct_io_write"(bool flag) except + + RemoteBackendType cpp_remote_backend "kvikio::defaults::remote_backend"() except + + void cpp_set_remote_backend \ + "kvikio::defaults::set_remote_backend"(RemoteBackendType remote_backend) except + + size_t cpp_remote_max_connections "kvikio::defaults::remote_max_connections"() except + + void cpp_set_remote_max_connections \ + "kvikio::defaults::set_remote_max_connections"(size_t remote_max_connections) except + + size_t cpp_num_bounce_buffers "kvikio::defaults::num_bounce_buffers"() except + + void cpp_set_num_bounce_buffers \ + "kvikio::defaults::set_num_bounce_buffers"(size_t num_bounce_buffers) except + def is_compat_mode_preferred() -> bool: @@ -179,3 +191,42 @@ def set_auto_direct_io_write(flag: bool) -> None: cdef bool cpp_flag = flag with nogil: cpp_set_auto_direct_io_write(cpp_flag) + + +def remote_backend() -> RemoteBackendType: + cdef RemoteBackendType result + with nogil: + result = cpp_remote_backend() + return result + + +def set_remote_backend(remote_backend: RemoteBackendType) -> None: + cdef RemoteBackendType cpp_remote_backend = remote_backend + with nogil: + cpp_set_remote_backend(cpp_remote_backend) + + +def remote_max_connections() -> int: + cdef size_t result + with nogil: + result = cpp_remote_max_connections() + return result + + +def set_remote_max_connections(remote_max_connections: int) -> None: + cdef size_t cpp_remote_max_connections = remote_max_connections + with nogil: + cpp_set_remote_max_connections(cpp_remote_max_connections) + + +def num_bounce_buffers() -> int: + cdef size_t result + with nogil: + result = cpp_num_bounce_buffers() + return result + + +def set_num_bounce_buffers(num_bounce_buffers: int) -> None: + cdef size_t cpp_num_bounce_buffers = num_bounce_buffers + with nogil: + cpp_set_num_bounce_buffers(cpp_num_bounce_buffers) diff --git a/python/kvikio/kvikio/defaults.py b/python/kvikio/kvikio/defaults.py index 3af9bd0929..ae8c96a58f 100644 --- a/python/kvikio/kvikio/defaults.py +++ b/python/kvikio/kvikio/defaults.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -58,6 +58,9 @@ def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]: "http_timeout", "auto_direct_io_read", "auto_direct_io_write", + "remote_backend", + "remote_max_connections", + "num_bounce_buffers", ] property_getters = {} @@ -127,6 +130,9 @@ def set(*config) -> ConfigContextManager: - ``"http_timeout"`` - ``"auto_direct_io_read"`` - ``"auto_direct_io_write"`` + - ``"remote_backend"`` + - ``"remote_max_connections"`` + - ``"num_bounce_buffers"`` Returns ------- @@ -172,6 +178,9 @@ def get(config_name: str) -> Any: - ``"http_timeout"`` - ``"auto_direct_io_read"`` - ``"auto_direct_io_write"`` + - ``"remote_backend"`` + - ``"remote_max_connections"`` + - ``"num_bounce_buffers"`` Returns ------- diff --git a/python/kvikio/tests/test_s3_io.py b/python/kvikio/tests/test_s3_io.py index d8610c73bc..0cfd9f63ec 100644 --- a/python/kvikio/tests/test_s3_io.py +++ b/python/kvikio/tests/test_s3_io.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import multiprocessing as mp @@ -112,7 +112,14 @@ def test_read_access(s3_base): @pytest.mark.parametrize("nthreads", [1, 3]) @pytest.mark.parametrize("tasksize", [99, 999]) @pytest.mark.parametrize("buffer_size", [101, 1001]) -def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size): +@pytest.mark.parametrize( + "remote_backend", + [ + kvikio.RemoteBackendType.LIBCURL_EASY, + kvikio.RemoteBackendType.LIBCURL_MULTI_POLL, + ], +) +def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size, remote_backend): bucket_name = "test_read" object_name = "Aa1" a = xp.arange(size) @@ -124,8 +131,16 @@ def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size): "num_threads": nthreads, "task_size": tasksize, "bounce_buffer_size": buffer_size, + "remote_backend": remote_backend, } ): + if ( + remote_backend == kvikio.RemoteBackendType.LIBCURL_MULTI_POLL + and tasksize > buffer_size + ): + pytest.skip( + "When the remote backend is LIBCURL_MULTI_POLL, task size must not be greater than the buffer size" + ) with kvikio.RemoteFile.open_s3_url( f"{server_address}/{bucket_name}/{object_name}" ) as f: @@ -144,18 +159,26 @@ def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size): (42, int(2**20)), ], ) -def test_read_with_file_offset(s3_base, xp, start, end): +@pytest.mark.parametrize( + "remote_backend", + [ + kvikio.RemoteBackendType.LIBCURL_EASY, + kvikio.RemoteBackendType.LIBCURL_MULTI_POLL, + ], +) +def test_read_with_file_offset(s3_base, xp, start, end, remote_backend): bucket_name = "test_read_with_file_offset" object_name = "Aa1" a = xp.arange(end, dtype=xp.int64) with s3_context( s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} ) as server_address: - url = f"{server_address}/{bucket_name}/{object_name}" - with kvikio.RemoteFile.open_s3_url(url) as f: - b = xp.zeros(shape=(end - start,), dtype=xp.int64) - assert f.read(b, file_offset=start * a.itemsize) == b.nbytes - xp.testing.assert_array_equal(a[start:end], b) + with kvikio.defaults.set({"remote_backend": remote_backend}): + url = f"{server_address}/{bucket_name}/{object_name}" + with kvikio.RemoteFile.open_s3_url(url) as f: + b = xp.zeros(shape=(end - start,), dtype=xp.int64) + assert f.read(b, file_offset=start * a.itemsize) == b.nbytes + xp.testing.assert_array_equal(a[start:end], b) @pytest.mark.parametrize("scheme", ["S3"])