From 1533ac236738ef9fdbbae9e30176927f55a9087a Mon Sep 17 00:00:00 2001 From: Nathan Delisle Date: Wed, 8 Oct 2025 21:19:57 -0500 Subject: [PATCH 1/2] training: avoid copying std::future and make ThreadPool::run move-friendly (bugfix) --- .../training/clustering/compression_utils.cpp | 11 ++++-- tools/training/utils/thread_pool.h | 36 ++++++++++++------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/tools/training/clustering/compression_utils.cpp b/tools/training/clustering/compression_utils.cpp index 6b11cb76..7a71d3ca 100644 --- a/tools/training/clustering/compression_utils.cpp +++ b/tools/training/clustering/compression_utils.cpp @@ -7,6 +7,7 @@ #include #include "openzl/common/a1cbor_helpers.h" #include "openzl/common/allocation.h" +#include #include "openzl/common/operation_context.h" #include "openzl/cpp/CCtx.hpp" @@ -197,10 +198,14 @@ std::future CompressionUtils::tryCompress( }; futures.emplace_back(threadPool_->run(task, configPtr, funcPtr)); } - return threadPool_->run([futures = std::move(futures)]() mutable { + // Make the lambda copyable for MSVC's packaged_task by capturing a shared_ptr: + auto futures_ptr = + std::make_shared>>(std::move(futures)); + + return threadPool_->run([futures_ptr]() { SizeTimePair result{}; - for (auto& future : futures) { - result = result + future.get(); + for (auto& fut : *futures_ptr) { + result = result + fut.get(); } return result; }); diff --git a/tools/training/utils/thread_pool.h b/tools/training/utils/thread_pool.h index 287b3d5e..a703eab4 100644 --- a/tools/training/utils/thread_pool.h +++ b/tools/training/utils/thread_pool.h @@ -5,9 +5,13 @@ #include #include #include +#include +#include +#include #include #include #include +#include namespace openzl::training { class ThreadPool { @@ -33,21 +37,27 @@ class ThreadPool { * @return std::future A future object that holds the result of * the task. */ - template - auto run(Func&& func, Args&&... args) - -> std::future - { - using ReturnType = decltype(func(args...)); - auto task = - std::make_shared>(std::bind( - std::forward(func), std::forward(args)...)); - std::future result = task->get_future(); + template + auto run(F&& f, Args&&... args) + -> std::future, std::decay_t...>> { + using ReturnType = std::invoke_result_t, std::decay_t...>; + using Task = std::packaged_task; + + // Capture callable + args by MOVE into a tuple (no copies of move-only types). + auto bound = [fn = std::forward(f), + tup = std::make_tuple(std::forward(args)...)]() mutable -> ReturnType { + return std::apply(fn, tup); + }; + + auto task = std::make_shared(std::move(bound)); + auto fut = task->get_future(); + { - std::lock_guard lock(queueMutex_); - taskQueue_.emplace([task]() { (*task)(); }); + std::lock_guard lock(queueMutex_); // correct name + taskQueue_.emplace([task]() mutable { (*task)(); }); // correct name } - condition_.notify_one(); - return result; + condition_.notify_one(); // correct name + return fut; } const size_t numThreads; From b040c8352840ac53dd8991a57c4c2881c8ab2bcd Mon Sep 17 00:00:00 2001 From: Nathan Delisle Date: Wed, 8 Oct 2025 21:33:21 -0500 Subject: [PATCH 2/2] style: clang-format --- .../training/clustering/compression_utils.cpp | 356 +++++++++--------- tools/training/utils/thread_pool.h | 95 ++--- 2 files changed, 219 insertions(+), 232 deletions(-) diff --git a/tools/training/clustering/compression_utils.cpp b/tools/training/clustering/compression_utils.cpp index 7a71d3ca..67b3c039 100644 --- a/tools/training/clustering/compression_utils.cpp +++ b/tools/training/clustering/compression_utils.cpp @@ -1,13 +1,13 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. #include "tools/training/clustering/compression_utils.h" +#include "openzl/common/a1cbor_helpers.h" +#include "openzl/common/allocation.h" #include #include +#include #include #include -#include "openzl/common/a1cbor_helpers.h" -#include "openzl/common/allocation.h" -#include #include "openzl/common/operation_context.h" #include "openzl/cpp/CCtx.hpp" @@ -17,210 +17,196 @@ namespace openzl::training { namespace { -int getTag(const Input& input) -{ - auto meta = input.getIntMetadata(ZL_CLUSTERING_TAG_METADATA_ID); - if (!meta) { - throw Exception("Stream provided has no metadata"); - } - return meta.value(); +int getTag(const Input &input) { + auto meta = input.getIntMetadata(ZL_CLUSTERING_TAG_METADATA_ID); + if (!meta) { + throw Exception("Stream provided has no metadata"); + } + return meta.value(); } -ColumnInfo getColumnInfo(const Input& input) -{ - return (ColumnInfo){ - .tag = getTag(input), - .type = typeToCType(input.type()), - .width = input.eltWidth(), - }; +ColumnInfo getColumnInfo(const Input &input) { + return (ColumnInfo){ + .tag = getTag(input), + .type = typeToCType(input.type()), + .width = input.eltWidth(), + }; } } // namespace -ClusterInfo CompressionUtils::getBestClusterInfo( - const std::unordered_set& tags, - ZL_Type type, - size_t eltWidth, - const ColumnMetadata& metadata) const -{ - ClusterInfo bestClusterInfo; - if (tags.size() == 0) { - throw Exception("No tags provided"); - } - // Check that there is a config for each tag - for (auto& tag : tags) { - auto column = - (ColumnInfo){ .tag = tag, .type = type, .width = eltWidth }; - if (metadata.count(column) == 0) { - throw std::runtime_error( - "No tag found in metadata for provided type and eltWidth"); - } +ClusterInfo +CompressionUtils::getBestClusterInfo(const std::unordered_set &tags, + ZL_Type type, size_t eltWidth, + const ColumnMetadata &metadata) const { + ClusterInfo bestClusterInfo; + if (tags.size() == 0) { + throw Exception("No tags provided"); + } + // Check that there is a config for each tag + for (auto &tag : tags) { + auto column = (ColumnInfo){.tag = tag, .type = type, .width = eltWidth}; + if (metadata.count(column) == 0) { + throw std::runtime_error( + "No tag found in metadata for provided type and eltWidth"); } + } - // Set to compress only the relevant successors - std::function filter = [tags](ColumnInfo val) { - return tags.count(val.tag) != 0; - }; + // Set to compress only the relevant successors + std::function filter = [tags](ColumnInfo val) { + return tags.count(val.tag) != 0; + }; - auto configBuilder = - ClusteringConfigBuilder::buildConfigSingleClusterWithSuccessor( - tags, type, eltWidth, 0, 0); - // Set up a config that clusters tags together - for (size_t i = 0; i < successors_.size(); i++) { - configBuilder.setClusterSuccessor(0, i); - auto succType = - ZL_Compressor_Graph_getInput0Mask(compressor_, successors_[i]); - // If the type is serial, allow automatic conversion from numeric/struct - if (succType & 0b1) { - succType = (ZL_Type)(succType | 0b110); - } - if (!(type & succType) - || typeToClusteringCodecIdxsMap_.count(type) == 0) { - continue; - } - auto clusteringCodecIdxs = typeToClusteringCodecIdxsMap_.at(type); - for (size_t j = 0; j < clusteringCodecIdxs.size(); j++) { - SizeTimePair cost{ 0, 0 }; - configBuilder.setClusteringCodec(0, clusteringCodecIdxs[j]); - auto config = configBuilder.build(); - for (auto& sample : samples_) { - cost = cost + compressSample(config, filter, sample); - } - if (cost < bestClusterInfo.cost) { - bestClusterInfo = { .successorIdx = i, - .clusteringCodecIdx = - clusteringCodecIdxs[j], - .cost = cost }; - } - } + auto configBuilder = + ClusteringConfigBuilder::buildConfigSingleClusterWithSuccessor( + tags, type, eltWidth, 0, 0); + // Set up a config that clusters tags together + for (size_t i = 0; i < successors_.size(); i++) { + configBuilder.setClusterSuccessor(0, i); + auto succType = + ZL_Compressor_Graph_getInput0Mask(compressor_, successors_[i]); + // If the type is serial, allow automatic conversion from numeric/struct + if (succType & 0b1) { + succType = (ZL_Type)(succType | 0b110); + } + if (!(type & succType) || typeToClusteringCodecIdxsMap_.count(type) == 0) { + continue; + } + auto clusteringCodecIdxs = typeToClusteringCodecIdxsMap_.at(type); + for (size_t j = 0; j < clusteringCodecIdxs.size(); j++) { + SizeTimePair cost{0, 0}; + configBuilder.setClusteringCodec(0, clusteringCodecIdxs[j]); + auto config = configBuilder.build(); + for (auto &sample : samples_) { + cost = cost + compressSample(config, filter, sample); + } + if (cost < bestClusterInfo.cost) { + bestClusterInfo = {.successorIdx = i, + .clusteringCodecIdx = clusteringCodecIdxs[j], + .cost = cost}; + } } - return bestClusterInfo; + } + return bestClusterInfo; } -SizeTimePair CompressionUtils::compressSample( - const ClusteringConfig& config, - const std::function& filter, - const MultiInput& sample) const -{ - // Set up local params for clustering - uint8_t* dst = NULL; - size_t dstSize = 0; - auto arena = detail::NonNullUniqueCPtr( - ALLOC_HeapArena_create(), ALLOC_Arena_freeArena); - A1C_Arena a1cArena = A1C_Arena_wrap(arena.get()); - openzl::CCtx cctx; - auto errCtx = ZL_CCtx_getOperationContext(cctx.get())->defaultScopeContext; - cctx.unwrap( - ZL_Clustering_serializeClusteringConfig( - errCtx, &dst, &dstSize, config.get(), &a1cArena), - "Failed to serialize clustering config"); - ZL_IntParam sizeParam = (ZL_IntParam){ - .paramId = ZL_GENERIC_CLUSTERING_CONFIG_SIZE_ID, - .paramValue = (int)dstSize, - }; - ZL_CopyParam configParam = (ZL_CopyParam){ - .paramId = ZL_GENERIC_CLUSTERING_CONFIG_ID, - .paramPtr = dst, - .paramSize = dstSize, - }; - ZL_LocalParams clusteringParams = (ZL_LocalParams){ - .intParams = { .intParams = &sizeParam, .nbIntParams = 1 }, - .copyParams = { .copyParams = &configParam, .nbCopyParams = 1 }, - }; - ZL_RuntimeGraphParameters runtimeParams = (ZL_RuntimeGraphParameters){ - .customGraphs = successors_.data(), - .nbCustomGraphs = successors_.size(), - .customNodes = clusteringCodecs_.data(), - .nbCustomNodes = clusteringCodecs_.size(), - .localParams = &clusteringParams, - }; +SizeTimePair +CompressionUtils::compressSample(const ClusteringConfig &config, + const std::function &filter, + const MultiInput &sample) const { + // Set up local params for clustering + uint8_t *dst = NULL; + size_t dstSize = 0; + auto arena = detail::NonNullUniqueCPtr(ALLOC_HeapArena_create(), + ALLOC_Arena_freeArena); + A1C_Arena a1cArena = A1C_Arena_wrap(arena.get()); + openzl::CCtx cctx; + auto errCtx = ZL_CCtx_getOperationContext(cctx.get())->defaultScopeContext; + cctx.unwrap(ZL_Clustering_serializeClusteringConfig(errCtx, &dst, &dstSize, + config.get(), &a1cArena), + "Failed to serialize clustering config"); + ZL_IntParam sizeParam = (ZL_IntParam){ + .paramId = ZL_GENERIC_CLUSTERING_CONFIG_SIZE_ID, + .paramValue = (int)dstSize, + }; + ZL_CopyParam configParam = (ZL_CopyParam){ + .paramId = ZL_GENERIC_CLUSTERING_CONFIG_ID, + .paramPtr = dst, + .paramSize = dstSize, + }; + ZL_LocalParams clusteringParams = (ZL_LocalParams){ + .intParams = {.intParams = &sizeParam, .nbIntParams = 1}, + .copyParams = {.copyParams = &configParam, .nbCopyParams = 1}, + }; + ZL_RuntimeGraphParameters runtimeParams = (ZL_RuntimeGraphParameters){ + .customGraphs = successors_.data(), + .nbCustomGraphs = successors_.size(), + .customNodes = clusteringCodecs_.data(), + .nbCustomNodes = clusteringCodecs_.size(), + .localParams = &clusteringParams, + }; - cctx.unwrap(ZL_CCtx_selectStartingGraphID( - cctx.get(), compressor_, ZL_GRAPH_CLUSTERING, &runtimeParams)); - cctx.setParameter(openzl::CParam::FormatVersion, ZL_MAX_FORMAT_VERSION); - size_t compressBound = 0; - std::vector constInputs; - for (auto& input : *sample) { - auto column = getColumnInfo(input); - if (!filter(column)) { - continue; - } - compressBound += ZL_compressBound( - (input.contentSize() + input.numElts() * 4) - * compressBoundFactor_); - constInputs.push_back(input.get()); + cctx.unwrap(ZL_CCtx_selectStartingGraphID( + cctx.get(), compressor_, ZL_GRAPH_CLUSTERING, &runtimeParams)); + cctx.setParameter(openzl::CParam::FormatVersion, ZL_MAX_FORMAT_VERSION); + size_t compressBound = 0; + std::vector constInputs; + for (auto &input : *sample) { + auto column = getColumnInfo(input); + if (!filter(column)) { + continue; } - if (constInputs.empty()) { - return (SizeTimePair){ 0, 0 }; + compressBound += ZL_compressBound( + (input.contentSize() + input.numElts() * 4) * compressBoundFactor_); + constInputs.push_back(input.get()); + } + if (constInputs.empty()) { + return (SizeTimePair){0, 0}; + } + std::string compressed(compressBound, 0); + auto start = std::chrono::high_resolution_clock::now(); + ZL_Report csize = ZL_CCtx_compressMultiTypedRef( + cctx.get(), compressed.data(), compressed.size(), constInputs.data(), + constInputs.size()); + auto stop = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(stop - start); + // TODO: T231098760: This implementation is a hack to get around the current + // state of csv successors + if (ZL_isError(csize)) { + static std::atomic errorLogged{false}; + if (!errorLogged.exchange(true)) { + // Only log the first occurrence of this error + ZL_LOG(ERROR, "Selected a successor that fails to compress on input, " + "treating this as a candidate with a large compression " + "cost. Suppressing future logs for this error."); } - std::string compressed(compressBound, 0); - auto start = std::chrono::high_resolution_clock::now(); - ZL_Report csize = ZL_CCtx_compressMultiTypedRef( - cctx.get(), - compressed.data(), - compressed.size(), - constInputs.data(), - constInputs.size()); - auto stop = std::chrono::high_resolution_clock::now(); - auto duration = - std::chrono::duration_cast(stop - start); - // TODO: T231098760: This implementation is a hack to get around the current - // state of csv successors - if (ZL_isError(csize)) { - static std::atomic errorLogged{ false }; - if (!errorLogged.exchange(true)) { - // Only log the first occurrence of this error - ZL_LOG(ERROR, - "Selected a successor that fails to compress on input, treating this as a candidate with a large compression cost. Suppressing future logs for this error."); - } - return SizeTimePair{ std::numeric_limits::max(), - std::numeric_limits::max() }; - } - cctx.unwrap(csize); - return SizeTimePair{ ZL_RES_value(csize), (size_t)duration.count() }; + return SizeTimePair{std::numeric_limits::max(), + std::numeric_limits::max()}; + } + cctx.unwrap(csize); + return SizeTimePair{ZL_RES_value(csize), (size_t)duration.count()}; } std::future CompressionUtils::tryCompress( - const ClusteringConfig& config, - const std::function& filter) const -{ - std::vector> futures; - // Copy clusteringConfig and filter into memory owned by ptrs and pass - // shared_ptrs into functions - auto configPtr = std::make_shared(*config); - auto funcPtr = - std::make_shared>(filter); - for (size_t i = 0; i < samples_.size(); i++) { - auto task = [this, - i](std::shared_ptr ccPtr, - std::shared_ptr> - fPtr) { - return compressSample(*ccPtr, *fPtr, samples_.at(i)); + const ClusteringConfig &config, + const std::function &filter) const { + std::vector> futures; + // Copy clusteringConfig and filter into memory owned by ptrs and pass + // shared_ptrs into functions + auto configPtr = std::make_shared(*config); + auto funcPtr = + std::make_shared>(filter); + for (size_t i = 0; i < samples_.size(); i++) { + auto task = + [this, i](std::shared_ptr ccPtr, + std::shared_ptr> fPtr) { + return compressSample(*ccPtr, *fPtr, samples_.at(i)); }; - futures.emplace_back(threadPool_->run(task, configPtr, funcPtr)); - } - // Make the lambda copyable for MSVC's packaged_task by capturing a shared_ptr: - auto futures_ptr = - std::make_shared>>(std::move(futures)); + futures.emplace_back(threadPool_->run(task, configPtr, funcPtr)); + } + // Make the lambda copyable for MSVC's packaged_task by capturing a + // shared_ptr: + auto futures_ptr = std::make_shared>>( + std::move(futures)); - return threadPool_->run([futures_ptr]() { - SizeTimePair result{}; - for (auto& fut : *futures_ptr) { - result = result + fut.get(); - } - return result; - }); + return threadPool_->run([futures_ptr]() { + SizeTimePair result{}; + for (auto &fut : *futures_ptr) { + result = result + fut.get(); + } + return result; + }); } -ColumnMetadata CompressionUtils::aggregateInputMetadata() const -{ - // TODO: Tags need to no longer uniquely idenify an input - ColumnMetadata metadata; - for (auto& sample : samples_) { - for (auto& input : *sample) { - metadata.insert(getColumnInfo(input)); - } +ColumnMetadata CompressionUtils::aggregateInputMetadata() const { + // TODO: Tags need to no longer uniquely idenify an input + ColumnMetadata metadata; + for (auto &sample : samples_) { + for (auto &input : *sample) { + metadata.insert(getColumnInfo(input)); } - return metadata; + } + return metadata; } } // namespace openzl::training diff --git a/tools/training/utils/thread_pool.h b/tools/training/utils/thread_pool.h index a703eab4..260ef3c0 100644 --- a/tools/training/utils/thread_pool.h +++ b/tools/training/utils/thread_pool.h @@ -5,68 +5,69 @@ #include #include #include +#include +#include +#include #include #include #include -#include -#include #include -#include namespace openzl::training { class ThreadPool { - public: - explicit ThreadPool(size_t numThreads_); +public: + explicit ThreadPool(size_t numThreads_); - ThreadPool(const ThreadPool&) = delete; - ThreadPool(ThreadPool&&) = delete; - ThreadPool& operator=(const ThreadPool&) = delete; - ThreadPool& operator=(ThreadPool&&) = delete; + ThreadPool(const ThreadPool &) = delete; + ThreadPool(ThreadPool &&) = delete; + ThreadPool &operator=(const ThreadPool &) = delete; + ThreadPool &operator=(ThreadPool &&) = delete; - ~ThreadPool(); + ~ThreadPool(); - /** - * This function accepts a callable object and its arguments, packages them - * into a task, and enqueues the task for execution by the thread pool. It - * returns a `std::future` object that can be used to retrieve the result of - * the task once it has been executed. This function is intended to be used - * for running asynchronous tasks in parallel. - * - * @param func The callable object to be executed. - * @param args The arguments to be passed to the callable object. - * @return std::future A future object that holds the result of - * the task. - */ - template - auto run(F&& f, Args&&... args) - -> std::future, std::decay_t...>> { - using ReturnType = std::invoke_result_t, std::decay_t...>; - using Task = std::packaged_task; + /** + * This function accepts a callable object and its arguments, packages them + * into a task, and enqueues the task for execution by the thread pool. It + * returns a `std::future` object that can be used to retrieve the result of + * the task once it has been executed. This function is intended to be used + * for running asynchronous tasks in parallel. + * + * @param func The callable object to be executed. + * @param args The arguments to be passed to the callable object. + * @return std::future A future object that holds the result of + * the task. + */ + template + auto run(F &&f, Args &&...args) -> std::future< + std::invoke_result_t, std::decay_t...>> { + using ReturnType = + std::invoke_result_t, std::decay_t...>; + using Task = std::packaged_task; - // Capture callable + args by MOVE into a tuple (no copies of move-only types). - auto bound = [fn = std::forward(f), - tup = std::make_tuple(std::forward(args)...)]() mutable -> ReturnType { - return std::apply(fn, tup); - }; + // Capture callable + args by MOVE into a tuple (no copies of move-only + // types). + auto bound = [fn = std::forward(f), + tup = std::make_tuple(std::forward(args)...)]() mutable + -> ReturnType { return std::apply(fn, tup); }; - auto task = std::make_shared(std::move(bound)); - auto fut = task->get_future(); + auto task = std::make_shared(std::move(bound)); + auto fut = task->get_future(); - { - std::lock_guard lock(queueMutex_); // correct name - taskQueue_.emplace([task]() mutable { (*task)(); }); // correct name - } - condition_.notify_one(); // correct name - return fut; + { + std::lock_guard lock(queueMutex_); // correct name + taskQueue_.emplace([task]() mutable { (*task)(); }); // correct name } + condition_.notify_one(); // correct name + return fut; + } - const size_t numThreads; + const size_t numThreads; - private: - std::mutex queueMutex_; - std::queue> taskQueue_; - std::vector threads_; - std::condition_variable condition_; - bool stop_ = false; +private: + std::mutex queueMutex_; + std::queue> taskQueue_; + std::vector threads_; + std::condition_variable condition_; + bool stop_ = false; }; } // namespace openzl::training