Skip to content
Open
23 changes: 19 additions & 4 deletions cpp/include/rapidsmpf/shuffler/shuffler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <vector>

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/communicator/metadata_payload_exchange/tag.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/buffer_resource.hpp>
#include <rapidsmpf/memory/packed_data.hpp>
Expand Down Expand Up @@ -112,6 +113,8 @@ class Shuffler {
* @param br Buffer resource used to allocate temporary buffers and the shuffle result.
* @param finished_callback Callback to notify when a partition is finished.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -123,7 +126,8 @@ class Shuffler {
PartID total_num_partitions,
BufferResource* br,
FinishedCallback&& finished_callback,
PartitionOwner partition_owner = round_robin
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
);

/**
Expand All @@ -135,6 +139,8 @@ class Shuffler {
* @param total_num_partitions Total number of partitions in the shuffle.
* @param br Buffer resource used to allocate temporary buffers and the shuffle result.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -145,9 +151,18 @@ class Shuffler {
OpID op_id,
PartID total_num_partitions,
BufferResource* br,
PartitionOwner partition_owner = round_robin
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
)
: Shuffler(comm, op_id, total_num_partitions, br, nullptr, partition_owner) {}
: Shuffler(
comm,
op_id,
total_num_partitions,
br,
nullptr,
partition_owner,
std::move(mpe)
) {}

~Shuffler();

Expand Down Expand Up @@ -334,8 +349,8 @@ class Shuffler {
///< ready to be extracted by the user.

std::shared_ptr<Communicator> comm_;
std::unique_ptr<communicator::MetadataPayloadExchange> mpe_;
ProgressThread::FunctionID progress_thread_function_id_;
OpID const op_id_;

SpillManager::SpillFunctionID spill_function_id_;

Expand Down
22 changes: 15 additions & 7 deletions cpp/src/communicator/metadata_payload_exchange/tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,16 @@ void TagMetadataPayloadExchange::send(
);

// Send data immediately after metadata (if any)
if (message->data() != nullptr) {
if (payload_size > 0) {
fire_and_forget_.push_back(
comm_->send(message->release_data(), dst, gpu_data_tag_)
);
}
}

statistics_->add_duration_stat("comms-interface-send-messages", Clock::now() - t0);
statistics_->add_duration_stat(
"metadata-payload-exchange-send-messages", Clock::now() - t0
);
}

void TagMetadataPayloadExchange::progress() {
Expand All @@ -110,7 +112,9 @@ void TagMetadataPayloadExchange::progress() {

cleanup_completed_operations();

statistics_->add_duration_stat("comms-interface-progress", Clock::now() - t0);
statistics_->add_duration_stat(
"metadata-payload-exchange-progress", Clock::now() - t0
);
}

std::vector<std::unique_ptr<MetadataPayloadExchange::Message>>
Expand Down Expand Up @@ -176,7 +180,9 @@ void TagMetadataPayloadExchange::receive_metadata() {
);
}

statistics_->add_duration_stat("comms-interface-receive-metadata", Clock::now() - t0);
statistics_->add_duration_stat(
"metadata-payload-exchange-receive-metadata", Clock::now() - t0
);
}

std::vector<std::unique_ptr<MetadataPayloadExchange::Message>>
Expand Down Expand Up @@ -257,7 +263,7 @@ TagMetadataPayloadExchange::setup_data_receives() {
}

statistics_->add_duration_stat(
"comms-interface-setup-data-receives", Clock::now() - t0
"metadata-payload-exchange-setup-data-receives", Clock::now() - t0
);

return completed_messages;
Expand Down Expand Up @@ -318,14 +324,16 @@ TagMetadataPayloadExchange::complete_data_transfers() {
}

statistics_->add_duration_stat(
"comms-interface-complete-data-transfers", Clock::now() - t0
"metadata-payload-exchange-complete-data-transfers", Clock::now() - t0
);

return completed_messages;
}

// Lets the communicator test the outstanding fire-and-forget operations
// tracked in `fire_and_forget_`. The value returned by `test_some()` is
// deliberately discarded via `std::ignore`.
void TagMetadataPayloadExchange::cleanup_completed_operations() {
    // Nothing is pending, so skip the communicator call altogether.
    if (fire_and_forget_.empty()) {
        return;
    }
    // Test exactly once per invocation; an additional unguarded call would
    // defeat the emptiness check above and poll the communicator twice.
    // NOTE(review): presumably `test_some()` drops completed entries from the
    // container it is given - confirm against the Communicator interface.
    std::ignore = comm_->test_some(fire_and_forget_);
}


Expand Down
Loading
Loading