diff --git a/cpp/include/rapidsmpf/shuffler/shuffler.hpp b/cpp/include/rapidsmpf/shuffler/shuffler.hpp index ff9bdecab..d897f5e7c 100644 --- a/cpp/include/rapidsmpf/shuffler/shuffler.hpp +++ b/cpp/include/rapidsmpf/shuffler/shuffler.hpp @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -112,6 +113,8 @@ class Shuffler { * @param br Buffer resource used to allocate temporary and the shuffle result. * @param finished_callback Callback to notify when a partition is finished. * @param partition_owner Function to determine partition ownership. + * @param mpe Optional custom metadata payload exchange. If not provided, + * uses the default tag-based implementation. * * @note The caller promises that inserted buffers are stream-ordered with respect * to their own stream, and extracted buffers are likewise guaranteed to be stream- @@ -123,7 +126,8 @@ class Shuffler { PartID total_num_partitions, BufferResource* br, FinishedCallback&& finished_callback, - PartitionOwner partition_owner = round_robin + PartitionOwner partition_owner = round_robin, + std::unique_ptr mpe = nullptr ); /** @@ -135,6 +139,8 @@ class Shuffler { * @param total_num_partitions Total number of partitions in the shuffle. * @param br Buffer resource used to allocate temporary and the shuffle result. * @param partition_owner Function to determine partition ownership. + * @param mpe Optional custom metadata payload exchange. If not provided, + * uses the default tag-based implementation. * * @note The caller promises that inserted buffers are stream-ordered with respect * to their own stream, and extracted buffers are likewise guaranteed to be stream- @@ -145,9 +151,18 @@ class Shuffler { OpID op_id, PartID total_num_partitions, BufferResource* br, - PartitionOwner partition_owner = round_robin + PartitionOwner partition_owner = round_robin, + std::unique_ptr mpe = nullptr ) - : Shuffler(comm, op_id, total_num_partitions, br, nullptr, partition_owner) {} + : Shuffler( + comm, + op_id, + total_num_partitions, + br, + nullptr, + partition_owner, + std::move(mpe) + ) {} ~Shuffler(); @@ -334,8 +349,8 @@ class Shuffler { ///< ready to be extracted by the user. std::shared_ptr comm_; + std::unique_ptr mpe_; ProgressThread::FunctionID progress_thread_function_id_; - OpID const op_id_; SpillManager::SpillFunctionID spill_function_id_; diff --git a/cpp/src/communicator/metadata_payload_exchange/tag.cpp b/cpp/src/communicator/metadata_payload_exchange/tag.cpp index bcb58df63..c64ddab85 100644 --- a/cpp/src/communicator/metadata_payload_exchange/tag.cpp +++ b/cpp/src/communicator/metadata_payload_exchange/tag.cpp @@ -86,14 +86,16 @@ void TagMetadataPayloadExchange::send( ); // Send data immediately after metadata (if any) - if (message->data() != nullptr) { + if (payload_size > 0) { fire_and_forget_.push_back( comm_->send(message->release_data(), dst, gpu_data_tag_) ); } } - statistics_->add_duration_stat("comms-interface-send-messages", Clock::now() - t0); + statistics_->add_duration_stat( + "metadata-payload-exchange-send-messages", Clock::now() - t0 + ); } void TagMetadataPayloadExchange::progress() { @@ -110,7 +112,9 @@ void TagMetadataPayloadExchange::progress() { cleanup_completed_operations(); - statistics_->add_duration_stat("comms-interface-progress", Clock::now() - t0); + statistics_->add_duration_stat( + "metadata-payload-exchange-progress", Clock::now() - t0 + ); } std::vector> @@ -176,7 +180,9 @@ void TagMetadataPayloadExchange::receive_metadata() { ); } - statistics_->add_duration_stat("comms-interface-receive-metadata", Clock::now() - t0); + statistics_->add_duration_stat( + "metadata-payload-exchange-receive-metadata", Clock::now() - t0 + ); } std::vector> @@ -257,7 +263,7 @@ TagMetadataPayloadExchange::setup_data_receives() { } statistics_->add_duration_stat( - "comms-interface-setup-data-receives", Clock::now() - t0 + "metadata-payload-exchange-setup-data-receives", Clock::now() - t0 ); return completed_messages; @@ -318,14 +324,16 @@ TagMetadataPayloadExchange::complete_data_transfers() { } statistics_->add_duration_stat( - "comms-interface-complete-data-transfers", Clock::now() - t0 + "metadata-payload-exchange-complete-data-transfers", Clock::now() - t0 ); return completed_messages; } void TagMetadataPayloadExchange::cleanup_completed_operations() { - std::ignore = comm_->test_some(fire_and_forget_); + if (!fire_and_forget_.empty()) { + std::ignore = comm_->test_some(fire_and_forget_); + } } diff --git a/cpp/src/shuffler/shuffler.cpp b/cpp/src/shuffler/shuffler.cpp index 936a86ee5..1182b4f05 100644 --- a/cpp/src/shuffler/shuffler.cpp +++ b/cpp/src/shuffler/shuffler.cpp @@ -12,6 +12,8 @@ #include #include +#include +#include #include #include #include @@ -26,6 +28,41 @@ using namespace detail; namespace { +/** + * @brief Convert chunks into messages for communication. + * + * This function converts a vector of chunks into messages suitable for sending + * through the metadata payload exchange. Each chunk is serialized and its data + * buffer is released to create the message. + * + * @param chunks Vector of chunks to convert (will be moved from). + * @param peer_rank_fn Function to determine the destination rank for each chunk. + * + * @return A vector of message unique pointers ready to be sent. + */ +template +std::vector> +convert_chunks_to_messages( + std::vector&& chunks, PeerRankFn&& peer_rank_fn +) { + std::vector> messages; + messages.reserve(chunks.size()); + + for (auto&& chunk : chunks) { + auto dst = peer_rank_fn(chunk); + auto metadata = *chunk.serialize(); + auto data = chunk.release_data_buffer(); + + messages.push_back( + std::make_unique( + dst, std::move(metadata), std::move(data) + ) + ); + } + + return messages; +} + /** * @brief Spills memory buffers within a postbox, e.g., from device to host memory. * @@ -101,153 +138,77 @@ class Shuffler::Progress { RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("Shuffler.Progress", p_iters++); auto const t0_event_loop = Clock::now(); - // Tags for each stage of the shuffle - Tag const metadata_tag{shuffler_.op_id_, 0}; - Tag const gpu_data_tag{shuffler_.op_id_, 1}; - - auto& log = *shuffler_.comm_->logger(); auto& stats = *shuffler_.statistics_; + // Submit outgoing chunks to the metadata payload exchange { - auto const t0_send = Clock::now(); + auto const t0_submit_outgoing = Clock::now(); auto ready_chunks = shuffler_.outgoing_postbox_.extract_all_ready(); - RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("shuffle_send", ready_chunks.size()); - for (auto&& chunk : ready_chunks) { - auto dst = shuffler_.partition_owner( - shuffler_.comm_, chunk.part_id(), shuffler_.total_num_partitions - ); - log.trace("send to ", dst, ": ", chunk); - RAPIDSMPF_EXPECTS( - dst != shuffler_.comm_->rank(), "sending chunk to ourselves" - ); + RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("submit_outgoing", ready_chunks.size()); - fire_and_forget_.push_back( - shuffler_.comm_->send(chunk.serialize(), dst, metadata_tag) - ); - if (chunk.data_size() > 0) { - shuffler_.statistics_->add_bytes_stat( - "shuffle-payload-send", chunk.data_size() + if (!ready_chunks.empty()) { + auto peer_rank_fn = [&shuffler = + shuffler_](detail::Chunk const& chunk) -> Rank { + auto dst = shuffler.partition_owner( + shuffler.comm_, chunk.part_id(), shuffler.total_num_partitions + ); + shuffler.comm_->logger()->trace( + "submitting message to ", dst, ": ", chunk ); - fire_and_forget_.push_back(shuffler_.comm_->send( - chunk.release_data_buffer(), dst, gpu_data_tag - )); - } - } - stats.add_duration_stat("event-loop-send", Clock::now() - t0_send); - } - - // Receive any incoming metadata of remote chunks and place them in - // `incoming_chunks_`. - { - auto const t0_metadata_recv = Clock::now(); - RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("meta_recv"); - [[maybe_unused]] int recv_any_iters = - 0; // this will be stripped off if RAPIDSMPF_VERBOSE_INFO is not set - while (true) { - auto const [msg, src] = shuffler_.comm_->recv_any(metadata_tag); - if (msg) { - auto chunk = - Chunk::deserialize(*msg, shuffler_.br_, /*validate=*/false); - log.trace("recv_any from ", src, ": ", chunk); RAPIDSMPF_EXPECTS( - shuffler_.partition_owner( - shuffler_.comm_, - chunk.part_id(), - shuffler_.total_num_partitions - ) == shuffler_.comm_->rank(), - "receiving chunk not owned by us" + dst != shuffler.comm_->rank(), "sending message to ourselves" ); - incoming_chunks_[src].push_back(std::move(chunk)); - } else { - break; - } - recv_any_iters++; - } - stats.add_duration_stat( - "event-loop-metadata-recv", Clock::now() - t0_metadata_recv - ); - RAPIDSMPF_NVTX_MARKER_VERBOSE("meta_recv_iters", recv_any_iters); - } - - // Post receives for incoming chunks. Note that we start the allocation of chunks - // in received message order, but because the allocations run on different streams - // they might not complete and be ready in that order. To handle that, we separate - // incoming chunks by rank and then process chunks in FIFO order until we observe - // a non-ready chunk. - { - auto const t0_post_incoming_chunk_recv = Clock::now(); - for (auto& [src, chunks] : incoming_chunks_) { - RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("post_chunk_recv", chunks.size()); - std::ptrdiff_t n_processed = 0; - for (auto& chunk : chunks) { - log.trace("checking incoming chunk data from ", src, ": ", chunk); + return dst; + }; + for (auto const& chunk : ready_chunks) { if (chunk.data_size() > 0) { - if (!chunk.is_ready()) { - break; - } - - auto chunk_id = chunk.chunk_id(); - auto data_size = chunk.data_size(); - - // Setup to receive the chunk into `in_transit_*`. - auto future = shuffler_.comm_->recv( - src, gpu_data_tag, chunk.release_data_buffer() - ); - RAPIDSMPF_EXPECTS( - in_transit_futures_.emplace(chunk_id, std::move(future)) - .second, - "in transit future already exist" - ); - RAPIDSMPF_EXPECTS( - in_transit_chunks_.emplace(chunk_id, std::move(chunk)).second, - "in transit chunk already exist" - ); - shuffler_.statistics_->add_bytes_stat( - "shuffle-payload-recv", data_size - ); - } else { - // Control messages and metadata-only messages go - // directly to the ready postbox. - shuffler_.insert_into_ready_postbox(std::move(chunk)); + stats.add_bytes_stat("shuffle-payload-send", chunk.data_size()); } - n_processed++; } - chunks.erase(chunks.begin(), chunks.begin() + n_processed); - } + auto messages = + convert_chunks_to_messages(std::move(ready_chunks), peer_rank_fn); + + shuffler_.mpe_->send(std::move(messages)); + } stats.add_duration_stat( - "event-loop-post-incoming-chunk-recv", - Clock::now() - t0_post_incoming_chunk_recv + "event-loop-submit-outgoing", Clock::now() - t0_submit_outgoing ); } - // Check if any data in transit is finished. + // Process all communication operations and get completed chunks { - auto const t0_check_future_finish = Clock::now(); - RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE( - "check_fut_finish", in_transit_futures_.size() - ); - if (!in_transit_futures_.empty()) { - std::vector finished = - shuffler_.comm_->test_some(in_transit_futures_); - for (auto cid : finished) { - auto chunk = extract_value(in_transit_chunks_, cid); - auto future = extract_value(in_transit_futures_, cid); - chunk.set_data_buffer( - shuffler_.comm_->release_data(std::move(future)) - ); + auto const t0_process_comm = Clock::now(); + RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("process_communication"); + + shuffler_.mpe_->progress(); + auto completed_messages = shuffler_.mpe_->recv(); + + for (auto&& message : completed_messages) { + auto chunk = + detail::Chunk::deserialize(message->metadata(), shuffler_.br_, false); + if (message->data() != nullptr) { + std::ignore = chunk.release_data_buffer(); + chunk.set_data_buffer(message->release_data()); + } - shuffler_.insert_into_ready_postbox(std::move(chunk)); + RAPIDSMPF_EXPECTS( + shuffler_.partition_owner( + shuffler_.comm_, chunk.part_id(), shuffler_.total_num_partitions + ) == shuffler_.comm_->rank(), + "receiving chunk not owned by us" + ); + + if (chunk.data_size() > 0) { + stats.add_bytes_stat("shuffle-payload-recv", chunk.data_size()); } - } - // Check if we can free some of the outstanding futures. - if (!fire_and_forget_.empty()) { - std::ignore = shuffler_.comm_->test_some(fire_and_forget_); + shuffler_.insert_into_ready_postbox(std::move(chunk)); } + stats.add_duration_stat( - "event-loop-check-future-finish", Clock::now() - t0_check_future_finish + "event-loop-process-communication", Clock::now() - t0_process_comm ); } @@ -256,28 +217,13 @@ class Shuffler::Progress { // Return Done only if the shuffler is inactive (shutdown was called) _and_ // all containers are empty (all work is done). return (shuffler_.active_.load(std::memory_order_acquire) - || !( - fire_and_forget_.empty() - && std::ranges::all_of( - incoming_chunks_, [](auto const& kv) { return kv.second.empty(); } - ) - && in_transit_chunks_.empty() && in_transit_futures_.empty() - && shuffler_.outgoing_postbox_.empty() - )) + || !shuffler_.mpe_->is_idle() || !shuffler_.outgoing_postbox_.empty()) ? ProgressThread::ProgressState::InProgress : ProgressThread::ProgressState::Done; } private: Shuffler& shuffler_; - std::vector> - fire_and_forget_; ///< Ongoing "fire-and-forget" operations (non-blocking sends). - std::unordered_map> - incoming_chunks_; ///< Per-rank FIFO of chunks awaiting receive. - std::unordered_map - in_transit_chunks_; ///< Chunks currently in transit. - std::unordered_map> - in_transit_futures_; ///< Futures corresponding to in-transit chunks. #if RAPIDSMPF_VERBOSE_INFO std::int64_t p_iters = 0; ///< Number of progress iterations (for NVTX) @@ -304,7 +250,8 @@ Shuffler::Shuffler( PartID total_num_partitions, BufferResource* br, FinishedCallback&& finished_callback, - PartitionOwner partition_owner_fn + PartitionOwner partition_owner_fn, + std::unique_ptr mpe ) : total_num_partitions{total_num_partitions}, partition_owner{std::move(partition_owner_fn)}, @@ -320,7 +267,20 @@ Shuffler::Shuffler( safe_cast(total_num_partitions), }, comm_{std::move(comm)}, - op_id_{op_id}, + mpe_{ + mpe ? std::move(mpe) + : std::make_unique( + comm_, + op_id, + [this](std::size_t size) -> std::unique_ptr { + return br_->allocate( + br_->stream_pool().get_stream(), + br_->reserve_or_fail(size, MEMORY_TYPES) + ); + }, + br_->statistics() + ) + }, local_partitions_{local_partitions(comm_, total_num_partitions, partition_owner)}, finish_counter_{comm_->nranks(), local_partitions_, std::move(finished_callback)}, outbound_chunk_counter_(safe_cast(comm_->nranks()), 0),