From 747f78a91dc5559d5ae1b5d933d527c4dad34c3f Mon Sep 17 00:00:00 2001 From: Nathan Delisle Date: Wed, 8 Oct 2025 21:16:26 -0500 Subject: [PATCH 1/2] tools: add missing STL includes for Windows/clang-cl (bugfix) --- tools/io/InputSet.cpp | 2 ++ tools/io/InputSetBuilder.h | 4 ++++ tools/io/InputSetDir.cpp | 2 ++ tools/io/InputSetMulti.cpp | 2 ++ tools/io/InputSetStatic.cpp | 2 ++ tools/logger/Logger.h | 1 + tools/training/clustering/trainers/bottom_up_trainer.cpp | 2 ++ tools/training/clustering/trainers/greedy_trainer.cpp | 1 + tools/training/utils/genetic_algorithm.cpp | 2 ++ 9 files changed, 18 insertions(+) diff --git a/tools/io/InputSet.cpp b/tools/io/InputSet.cpp index 8a6d5850..0d6e26a8 100644 --- a/tools/io/InputSet.cpp +++ b/tools/io/InputSet.cpp @@ -1,6 +1,8 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. #include "tools/io/InputSet.h" +#include + namespace openzl::tools::io { diff --git a/tools/io/InputSetBuilder.h b/tools/io/InputSetBuilder.h index cd4b30a5..7a308682 100644 --- a/tools/io/InputSetBuilder.h +++ b/tools/io/InputSetBuilder.h @@ -4,6 +4,10 @@ #include #include +#include +#include +#include + #include "tools/io/InputSet.h" diff --git a/tools/io/InputSetDir.cpp b/tools/io/InputSetDir.cpp index 64c89530..26507474 100644 --- a/tools/io/InputSetDir.cpp +++ b/tools/io/InputSetDir.cpp @@ -3,6 +3,8 @@ #include "tools/io/InputSetDir.h" #include +#include + #include "tools/io/InputFile.h" diff --git a/tools/io/InputSetMulti.cpp b/tools/io/InputSetMulti.cpp index 04e18ba5..d1199272 100644 --- a/tools/io/InputSetMulti.cpp +++ b/tools/io/InputSetMulti.cpp @@ -1,6 +1,8 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. #include "tools/io/InputSetMulti.h" +#include + namespace openzl::tools::io { diff --git a/tools/io/InputSetStatic.cpp b/tools/io/InputSetStatic.cpp index 00ef2e80..c087a7c4 100644 --- a/tools/io/InputSetStatic.cpp +++ b/tools/io/InputSetStatic.cpp @@ -1,6 +1,8 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. #include "tools/io/InputSetStatic.h" +#include + namespace openzl::tools::io { diff --git a/tools/logger/Logger.h b/tools/logger/Logger.h index 18527623..1d34b3a6 100644 --- a/tools/logger/Logger.h +++ b/tools/logger/Logger.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include diff --git a/tools/training/clustering/trainers/bottom_up_trainer.cpp b/tools/training/clustering/trainers/bottom_up_trainer.cpp index 4f72d569..7f62237d 100644 --- a/tools/training/clustering/trainers/bottom_up_trainer.cpp +++ b/tools/training/clustering/trainers/bottom_up_trainer.cpp @@ -1,5 +1,7 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. +#include + #include "tools/training/clustering/trainers/bottom_up_trainer.h" #include "tools/logger/Logger.h" #include "tools/training/clustering/clustering_config_builder.h" diff --git a/tools/training/clustering/trainers/greedy_trainer.cpp b/tools/training/clustering/trainers/greedy_trainer.cpp index 66b72379..48e39504 100644 --- a/tools/training/clustering/trainers/greedy_trainer.cpp +++ b/tools/training/clustering/trainers/greedy_trainer.cpp @@ -3,6 +3,7 @@ #include "tools/training/clustering/trainers/greedy_trainer.h" #include +#include #include "tools/logger/Logger.h" #include "tools/training/clustering/clustering_config_builder.h" diff --git a/tools/training/utils/genetic_algorithm.cpp b/tools/training/utils/genetic_algorithm.cpp index f2a6aae1..1aa1a5c6 100644 --- a/tools/training/utils/genetic_algorithm.cpp +++ b/tools/training/utils/genetic_algorithm.cpp @@ -1,5 +1,7 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. +#include + #include "tools/training/utils/genetic_algorithm.h" namespace openzl { From 91654a98f0fec265695d5d960cc0dac44da6384d Mon Sep 17 00:00:00 2001 From: Nathan Delisle Date: Wed, 8 Oct 2025 21:33:47 -0500 Subject: [PATCH 2/2] style: clang-format --- tools/io/InputSet.cpp | 84 ++--- tools/io/InputSetBuilder.h | 29 +- tools/io/InputSetDir.cpp | 124 ++++--- tools/io/InputSetMulti.cpp | 145 ++++---- tools/io/InputSetStatic.cpp | 94 +++-- tools/logger/Logger.h | 322 ++++++++---------- .../clustering/trainers/bottom_up_trainer.cpp | 272 +++++++-------- 7 files changed, 481 insertions(+), 589 deletions(-) diff --git a/tools/io/InputSet.cpp b/tools/io/InputSet.cpp index 0d6e26a8..aa0056b1 100644 --- a/tools/io/InputSet.cpp +++ b/tools/io/InputSet.cpp @@ -3,78 +3,60 @@ #include "tools/io/InputSet.h" #include - namespace openzl::tools::io { -InputSet::Iterator InputSet::begin() const -{ - return Iterator{ begin_state() }; -} +InputSet::Iterator InputSet::begin() const { return Iterator{begin_state()}; } -InputSet::Iterator InputSet::end() const -{ - return Iterator{}; -} +InputSet::Iterator InputSet::end() const { return Iterator{}; } // end() InputSet::Iterator::Iterator() : state_() {} // begin() InputSet::Iterator::Iterator(std::unique_ptr state) - : state_((state && **state) ? std::move(state) - : std::unique_ptr{}) -{ -} + : state_((state && **state) ? std::move(state) + : std::unique_ptr{}) {} -InputSet::Iterator::Iterator(const Iterator& o) - : state_(o.state_ ? o.state_->copy() : std::unique_ptr{}) -{ -} +InputSet::Iterator::Iterator(const Iterator &o) + : state_(o.state_ ? o.state_->copy() : std::unique_ptr{}) {} -InputSet::Iterator& InputSet::Iterator::operator=(const Iterator& o) -{ - state_ = o.state_ ? o.state_->copy() : std::unique_ptr{}; - return *this; +InputSet::Iterator &InputSet::Iterator::operator=(const Iterator &o) { + state_ = o.state_ ? o.state_->copy() : std::unique_ptr{}; + return *this; } -const std::shared_ptr& InputSet::Iterator::operator*() const -{ - static const std::shared_ptr null_input{}; - if (!state_) { - throw std::runtime_error("Can't deref end InputSet::Iterator."); - } - return **state_; +const std::shared_ptr &InputSet::Iterator::operator*() const { + static const std::shared_ptr null_input{}; + if (!state_) { + throw std::runtime_error("Can't deref end InputSet::Iterator."); + } + return **state_; } -InputSet::Iterator& InputSet::Iterator::operator++() -{ - if (!state_) { - throw std::runtime_error( - "Can't advance InputSet::Iterator past the end of the InputSet."); - } - ++(*state_); - if (!**state_) { - state_.reset(); - } - return *this; +InputSet::Iterator &InputSet::Iterator::operator++() { + if (!state_) { + throw std::runtime_error( + "Can't advance InputSet::Iterator past the end of the InputSet."); + } + ++(*state_); + if (!**state_) { + state_.reset(); + } + return *this; } -InputSet::Iterator InputSet::Iterator::operator++(int) const -{ - Iterator new_it{ *this }; - ++new_it; - return new_it; +InputSet::Iterator InputSet::Iterator::operator++(int) const { + Iterator new_it{*this}; + ++new_it; + return new_it; } -bool InputSet::Iterator::operator==(const Iterator& o) const -{ - return (!state_ && !o.state_) - || (state_ && o.state_ && *state_ == *o.state_); +bool InputSet::Iterator::operator==(const Iterator &o) const { + return (!state_ && !o.state_) || (state_ && o.state_ && *state_ == *o.state_); } -bool InputSet::Iterator::operator!=(const Iterator& o) const -{ - return !(*this == o); +bool InputSet::Iterator::operator!=(const Iterator &o) const { + return !(*this == o); } } // namespace openzl::tools::io diff --git a/tools/io/InputSetBuilder.h b/tools/io/InputSetBuilder.h index 7a308682..fe8ed611 100644 --- a/tools/io/InputSetBuilder.h +++ b/tools/io/InputSetBuilder.h @@ -3,11 +3,10 @@ #pragma once #include -#include -#include #include #include - +#include +#include #include "tools/io/InputSet.h" @@ -17,23 +16,23 @@ namespace openzl::tools::io { * Helper class to build up an input set from a bunch of path arguments. */ class InputSetBuilder { - public: - explicit InputSetBuilder(bool recursive, bool verbose = false); +public: + explicit InputSetBuilder(bool recursive, bool verbose = false); - InputSetBuilder& add_path(std::string path) &; - InputSetBuilder&& add_path(std::string path) &&; + InputSetBuilder &add_path(std::string path) &; + InputSetBuilder &&add_path(std::string path) &&; - InputSetBuilder& add_path(std::optional path_opt) &; - InputSetBuilder&& add_path(std::optional path_opt) &&; + InputSetBuilder &add_path(std::optional path_opt) &; + InputSetBuilder &&add_path(std::optional path_opt) &&; - std::unique_ptr build() &&; + std::unique_ptr build() &&; - std::unique_ptr build_static() &&; + std::unique_ptr build_static() &&; - private: - const bool recursive_; - const bool verbose_; +private: + const bool recursive_; + const bool verbose_; - std::vector> input_sets_; + std::vector> input_sets_; }; } // namespace openzl::tools::io diff --git a/tools/io/InputSetDir.cpp b/tools/io/InputSetDir.cpp index 26507474..1f0e938b 100644 --- a/tools/io/InputSetDir.cpp +++ b/tools/io/InputSetDir.cpp @@ -5,93 +5,83 @@ #include #include - #include "tools/io/InputFile.h" namespace openzl::tools::io { template class InputSetDir::IteratorStateDir : public InputSet::IteratorState { - public: - IteratorStateDir(const InputSetDir& isd, const std::string& path) - : isd_(isd), it_(path) - { - advance_to_next_regular_file(); - } +public: + IteratorStateDir(const InputSetDir &isd, const std::string &path) + : isd_(isd), it_(path) { + advance_to_next_regular_file(); + } - std::unique_ptr copy() const override - { - return std::make_unique(*this); - } + std::unique_ptr copy() const override { + return std::make_unique(*this); + } - // Advance to the next regular file. - IteratorState& operator++() override - { - input_.reset(); - if (it_ == IterT{}) { - throw std::runtime_error( - "Can't advance iterator past the end of the InputSet."); - } - ++it_; - advance_to_next_regular_file(); - return *this; + // Advance to the next regular file. + IteratorState &operator++() override { + input_.reset(); + if (it_ == IterT{}) { + throw std::runtime_error( + "Can't advance iterator past the end of the InputSet."); } + ++it_; + advance_to_next_regular_file(); + return *this; + } - const std::shared_ptr& operator*() const override - { - if (input_) { - return input_; - } - if (it_ == IterT{}) { - return input_; - } - input_ = std::make_shared(it_->path().string()); - return input_; + const std::shared_ptr &operator*() const override { + if (input_) { + return input_; } + if (it_ == IterT{}) { + return input_; + } + input_ = std::make_shared(it_->path().string()); + return input_; + } - bool operator==(const IteratorState& o) const override - { - auto ptr = dynamic_cast(&o); - if (ptr == nullptr) { - return false; - } - return (&isd_ == &ptr->isd_) && (it_ == ptr->it_); + bool operator==(const IteratorState &o) const override { + auto ptr = dynamic_cast(&o); + if (ptr == nullptr) { + return false; } + return (&isd_ == &ptr->isd_) && (it_ == ptr->it_); + } - private: - void advance_to_next_regular_file() - { - while (true) { - if (it_ == IterT{}) { - break; - } - if (it_->is_regular_file()) { - break; - } - ++it_; - } +private: + void advance_to_next_regular_file() { + while (true) { + if (it_ == IterT{}) { + break; + } + if (it_->is_regular_file()) { + break; + } + ++it_; } + } - const InputSetDir& isd_; - IterT it_; - mutable std::shared_ptr input_; + const InputSetDir &isd_; + IterT it_; + mutable std::shared_ptr input_; }; InputSetDir::InputSetDir(std::string path, bool recursive) - : path_(std::move(path)), recursive_(recursive) -{ -} + : path_(std::move(path)), recursive_(recursive) {} -std::unique_ptr InputSetDir::begin_state() const -{ - if (recursive_) { - return std::make_unique>(*this, path_); - } else { - return std::make_unique< - IteratorStateDir>( - *this, path_); - } +std::unique_ptr InputSetDir::begin_state() const { + if (recursive_) { + return std::make_unique< + IteratorStateDir>(*this, + path_); + } else { + return std::make_unique< + IteratorStateDir>(*this, path_); + } } } // namespace openzl::tools::io diff --git a/tools/io/InputSetMulti.cpp b/tools/io/InputSetMulti.cpp index d1199272..7016b11d 100644 --- a/tools/io/InputSetMulti.cpp +++ b/tools/io/InputSetMulti.cpp @@ -3,105 +3,94 @@ #include "tools/io/InputSetMulti.h" #include - namespace openzl::tools::io { class InputSetMulti::IteratorStateMulti : public InputSet::IteratorState { - public: - explicit IteratorStateMulti(const InputSetMulti& ism) - : ism_(ism), idx_(0), size_(ism.input_sets_.size()) - { - if (idx_ != size_) { - inner_it_ = ism_[idx_].begin(); - inner_end_ = ism_[idx_].end(); - } - advance_to_next_nonempty(); +public: + explicit IteratorStateMulti(const InputSetMulti &ism) + : ism_(ism), idx_(0), size_(ism.input_sets_.size()) { + if (idx_ != size_) { + inner_it_ = ism_[idx_].begin(); + inner_end_ = ism_[idx_].end(); } + advance_to_next_nonempty(); + } - std::unique_ptr copy() const override - { - return std::make_unique(*this); - } + std::unique_ptr copy() const override { + return std::make_unique(*this); + } - IteratorState& operator++() override - { - if (idx_ == size_) { - throw std::runtime_error( - "Can't advance iterator past the end of the InputSet."); - } - if (inner_it_ != inner_end_) { - ++inner_it_; - } - if (inner_it_ == inner_end_) { - idx_++; - if (idx_ != size_) { - inner_it_ = ism_[idx_].begin(); - inner_end_ = ism_[idx_].end(); - } - advance_to_next_nonempty(); - } - return *this; + IteratorState &operator++() override { + if (idx_ == size_) { + throw std::runtime_error( + "Can't advance iterator past the end of the InputSet."); + } + if (inner_it_ != inner_end_) { + ++inner_it_; + } + if (inner_it_ == inner_end_) { + idx_++; + if (idx_ != size_) { + inner_it_ = ism_[idx_].begin(); + inner_end_ = ism_[idx_].end(); + } + advance_to_next_nonempty(); } + return *this; + } - const std::shared_ptr& operator*() const override - { - static const std::shared_ptr null_input{}; - if (idx_ == ism_.input_sets_.size()) { - return null_input; - } - if (inner_it_ == inner_end_) { - return null_input; - } - return *inner_it_; + const std::shared_ptr &operator*() const override { + static const std::shared_ptr null_input{}; + if (idx_ == ism_.input_sets_.size()) { + return null_input; } + if (inner_it_ == inner_end_) { + return null_input; + } + return *inner_it_; + } - bool operator==(const IteratorState& o) const override - { - auto ptr = dynamic_cast(&o); - if (ptr == nullptr) { - return false; - } - return (&ism_ == &ptr->ism_) && (idx_ == ptr->idx_) - && (inner_it_ == ptr->inner_it_); + bool operator==(const IteratorState &o) const override { + auto ptr = dynamic_cast(&o); + if (ptr == nullptr) { + return false; } + return (&ism_ == &ptr->ism_) && (idx_ == ptr->idx_) && + (inner_it_ == ptr->inner_it_); + } - private: - void advance_to_next_nonempty() - { - if (idx_ == size_) { - return; - } - while (inner_it_ == inner_end_) { - idx_++; - if (idx_ == size_) { - break; - } - inner_it_ = ism_[idx_].begin(); - inner_end_ = ism_[idx_].end(); - } +private: + void advance_to_next_nonempty() { + if (idx_ == size_) { + return; + } + while (inner_it_ == inner_end_) { + idx_++; + if (idx_ == size_) { + break; + } + inner_it_ = ism_[idx_].begin(); + inner_end_ = ism_[idx_].end(); } + } - const InputSetMulti& ism_; - size_t idx_; - const size_t size_; + const InputSetMulti &ism_; + size_t idx_; + const size_t size_; - InputSet::Iterator inner_it_; - InputSet::Iterator inner_end_; + InputSet::Iterator inner_it_; + InputSet::Iterator inner_end_; }; InputSetMulti::InputSetMulti(std::vector> input_sets) - : input_sets_(std::move(input_sets)) -{ -} + : input_sets_(std::move(input_sets)) {} -const InputSet& InputSetMulti::operator[](size_t idx) const -{ - return *input_sets_[idx]; +const InputSet &InputSetMulti::operator[](size_t idx) const { + return *input_sets_[idx]; } -std::unique_ptr InputSetMulti::begin_state() const -{ - return std::make_unique(*this); +std::unique_ptr InputSetMulti::begin_state() const { + return std::make_unique(*this); } } // namespace openzl::tools::io diff --git a/tools/io/InputSetStatic.cpp b/tools/io/InputSetStatic.cpp index c087a7c4..e9601e73 100644 --- a/tools/io/InputSetStatic.cpp +++ b/tools/io/InputSetStatic.cpp @@ -3,80 +3,66 @@ #include "tools/io/InputSetStatic.h" #include - namespace openzl::tools::io { class InputSetStatic::IteratorStateStatic : public InputSet::IteratorState { - public: - explicit IteratorStateStatic(const InputSetStatic& iss, size_t idx) - : iss_(iss), idx_(idx) - { - } +public: + explicit IteratorStateStatic(const InputSetStatic &iss, size_t idx) + : iss_(iss), idx_(idx) {} - std::unique_ptr copy() const override - { - return std::make_unique(iss_, idx_); - } + std::unique_ptr copy() const override { + return std::make_unique(iss_, idx_); + } - IteratorState& operator++() override - { - idx_++; - return *this; - } + IteratorState &operator++() override { + idx_++; + return *this; + } - const std::shared_ptr& operator*() const override - { - return iss_[idx_]; - } + const std::shared_ptr &operator*() const override { + return iss_[idx_]; + } - bool operator==(const IteratorState& o) const override - { - auto ptr = dynamic_cast(&o); - if (ptr == nullptr) { - return false; - } - return (&iss_ == &ptr->iss_) && (idx_ == ptr->idx_); + bool operator==(const IteratorState &o) const override { + auto ptr = dynamic_cast(&o); + if (ptr == nullptr) { + return false; } + return (&iss_ == &ptr->iss_) && (idx_ == ptr->idx_); + } - private: - const InputSetStatic& iss_; - size_t idx_{ 0 }; +private: + const InputSetStatic &iss_; + size_t idx_{0}; }; InputSetStatic::InputSetStatic(std::vector> inputs) - : inputs_(std::move(inputs)) -{ - for (const auto& input : inputs_) { - if (!input) { - throw std::runtime_error("InputSetStatic cannot hold a nullptr."); - } + : inputs_(std::move(inputs)) { + for (const auto &input : inputs_) { + if (!input) { + throw std::runtime_error("InputSetStatic cannot hold a nullptr."); } + } } -InputSetStatic InputSetStatic::from_input_set(InputSet& input_set) -{ - std::vector> inputs{ input_set.begin(), - input_set.end() }; - return InputSetStatic{ std::move(inputs) }; +InputSetStatic InputSetStatic::from_input_set(InputSet &input_set) { + std::vector> inputs{input_set.begin(), + input_set.end()}; + return InputSetStatic{std::move(inputs)}; } -size_t InputSetStatic::size() const -{ - return inputs_.size(); -} +size_t InputSetStatic::size() const { return inputs_.size(); } -const std::shared_ptr& InputSetStatic::operator[](size_t idx) const -{ - static const std::shared_ptr null_input{}; - if (idx < inputs_.size()) { - return inputs_[idx]; - } - return null_input; +const std::shared_ptr &InputSetStatic::operator[](size_t idx) const { + static const std::shared_ptr null_input{}; + if (idx < inputs_.size()) { + return inputs_[idx]; + } + return null_input; } -std::unique_ptr InputSetStatic::begin_state() const -{ - return std::make_unique(*this, 0); +std::unique_ptr InputSetStatic::begin_state() const { + return std::make_unique(*this, 0); } } // namespace openzl::tools::io diff --git a/tools/logger/Logger.h b/tools/logger/Logger.h index 1d34b3a6..5f67e9de 100644 --- a/tools/logger/Logger.h +++ b/tools/logger/Logger.h @@ -5,22 +5,22 @@ #include #include #include -#include #include +#include #include namespace openzl::tools::logger { // Global verbosity functions enum LogLevel { - ALWAYS = 0, - ERRORS = 1, - WARNINGS = 2, - INFO = 3, - VERBOSE1 = 4, - VERBOSE2 = 5, - VERBOSE3 = 6, - EVERYTHING = 7, + ALWAYS = 0, + ERRORS = 1, + WARNINGS = 2, + INFO = 3, + VERBOSE1 = 4, + VERBOSE2 = 5, + VERBOSE3 = 6, + EVERYTHING = 7, }; const int progressBarWidth = 50; @@ -29,200 +29,176 @@ const int progressBarWidth = 50; * Logger class for CLI and training */ class Logger { - public: - static Logger& instance() - { - static Logger instance_; - return instance_; - } - - int global_verbosity; // TODO where to set this by default? - bool progress_line_active; // Track if we have an active progress line - - // Store current progress information for re-printing - LogLevel progress_level; - double progress_value; - std::string progress_message; - - void setGlobalLoggerVerbosity(int verbosity) - { - if (verbosity < static_cast(ALWAYS) - || verbosity > static_cast(EVERYTHING)) { - throw std::invalid_argument( - "Invalid log level: " + std::to_string(verbosity) - + ". Valid levels are " - + std::to_string(static_cast(ALWAYS)) + " (ALWAYS) to " - + std::to_string(static_cast(EVERYTHING)) - + " (EVERYTHING)."); - } - global_verbosity = verbosity; +public: + static Logger &instance() { + static Logger instance_; + return instance_; + } + + int global_verbosity; // TODO where to set this by default? + bool progress_line_active; // Track if we have an active progress line + + // Store current progress information for re-printing + LogLevel progress_level; + double progress_value; + std::string progress_message; + + void setGlobalLoggerVerbosity(int verbosity) { + if (verbosity < static_cast(ALWAYS) || + verbosity > static_cast(EVERYTHING)) { + throw std::invalid_argument( + "Invalid log level: " + std::to_string(verbosity) + + ". Valid levels are " + std::to_string(static_cast(ALWAYS)) + + " (ALWAYS) to " + std::to_string(static_cast(EVERYTHING)) + + " (EVERYTHING)."); } - int getGlobalLoggerVerbosity() - { - return global_verbosity; + global_verbosity = verbosity; + } + int getGlobalLoggerVerbosity() { return global_verbosity; } + +private: + Logger() + : global_verbosity(INFO), progress_line_active(false), + progress_level(INFO), progress_value(0.0), progress_message("") {} + Logger(const Logger &) = delete; + Logger &operator=(const Logger &) = delete; + + template + static void log_inner(std::ostream &os, const Arg &arg) { + os << arg; + } + template + static void log_inner(std::ostream &os, const Arg &arg, const Args &...args) { + log_inner(os << arg, args...); + } + +public: + template + static void log(LogLevel level, const Args &...args) { + if (shouldLog(level)) { + finalizeProgressIfActive(); + log_inner(std::cerr, args...); + std::cerr << '\n'; + reprintProgressIfActive(); } + } - private: - Logger() - : global_verbosity(INFO), - progress_line_active(false), - progress_level(INFO), - progress_value(0.0), - progress_message("") - { +public: + template + static void log_c(LogLevel level, const char *format, const Args &...args) { + if (!shouldLog(level)) { + return; } - Logger(const Logger&) = delete; - Logger& operator=(const Logger&) = delete; - - template - static void log_inner(std::ostream& os, const Arg& arg) - { - os << arg; - } - template - static void log_inner(std::ostream& os, const Arg& arg, const Args&... args) - { - log_inner(os << arg, args...); - } - - public: - template - static void log(LogLevel level, const Args&... args) - { - if (shouldLog(level)) { - finalizeProgressIfActive(); - log_inner(std::cerr, args...); - std::cerr << '\n'; - reprintProgressIfActive(); - } - } - - public: - template - static void log_c(LogLevel level, const char* format, const Args&... args) - { - if (!shouldLog(level)) { - return; - } - finalizeProgressIfActive(); + finalizeProgressIfActive(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wformat-nonliteral" - fprintf(stderr, format, args...); - fprintf(stderr, "\n"); + fprintf(stderr, format, args...); + fprintf(stderr, "\n"); #pragma GCC diagnostic pop - reprintProgressIfActive(); - } + reprintProgressIfActive(); + } - public: - template - static void update(LogLevel level, const char* format, const Args&... args) - { - if (!shouldLog(level)) { - return; - } +public: + template + static void update(LogLevel level, const char *format, const Args &...args) { + if (!shouldLog(level)) { + return; + } - // Move to beginning of line - // TODO remove control characters when printing to non-tty - fprintf(stderr, "\r"); + // Move to beginning of line + // TODO remove control characters when printing to non-tty + fprintf(stderr, "\r"); - // Print the formatted message + // Print the formatted message #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wformat-nonliteral" - fprintf(stderr, format, args...); + fprintf(stderr, format, args...); #pragma GCC diagnostic pop - // Clear to end of line to remove any remaining characters - // TODO remove control characters when printing to non-tty - fprintf(stderr, "%s", CLEAR_TO_EOL); + // Clear to end of line to remove any remaining characters + // TODO remove control characters when printing to non-tty + fprintf(stderr, "%s", CLEAR_TO_EOL); + + fflush(stderr); + } - fflush(stderr); + template + static void logProgress(LogLevel level, double progress, const char *format, + Args &&...args) { + if (!shouldLog(level)) { + return; } - template - static void logProgress( - LogLevel level, - double progress, - const char* format, - Args&&... args) - { - if (!shouldLog(level)) { - return; - } - - if (progress > 1.0) { - throw std::invalid_argument( - "Progress percentage must be <= 1.0, got: " - + std::to_string(progress) + "."); - } - - instance().progress_line_active = true; - - // Store current progress information for re-printing - instance().progress_level = level; - instance().progress_value = progress; - - // Build the user message part - std::string userMsg; - if (format && format[0] != '\0') { + if (progress > 1.0) { + throw std::invalid_argument("Progress percentage must be <= 1.0, got: " + + std::to_string(progress) + "."); + } + + instance().progress_line_active = true; + + // Store current progress information for re-printing + instance().progress_level = level; + instance().progress_value = progress; + + // Build the user message part + std::string userMsg; + if (format && format[0] != '\0') { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wformat-nonliteral" - int required_size = snprintf(nullptr, 0, format, args...); - if (required_size > 0) { - std::vector buffer( - required_size + 1); // +1 for null terminator - snprintf(buffer.data(), buffer.size(), format, args...); - userMsg = std::string(buffer.data()); - } + int required_size = snprintf(nullptr, 0, format, args...); + if (required_size > 0) { + std::vector buffer(required_size + 1); // +1 for null terminator + snprintf(buffer.data(), buffer.size(), format, args...); + userMsg = std::string(buffer.data()); + } #pragma GCC diagnostic pop - } - - // Build the progress message - int filled = (int)(progress * progressBarWidth); - char progressBar[progressBarWidth + 3]; // progressBarWidth + 2 for ends - // + 1 for null terminator - progressBar[0] = '['; - for (int i = 0; i < progressBarWidth; ++i) { - progressBar[i + 1] = (i < filled) ? '=' : '-'; - } - progressBar[progressBarWidth + 1] = ']'; - progressBar[progressBarWidth + 2] = '\0'; - instance().progress_message = std::string(progressBar) + " " + userMsg; - - update(level, "%s", instance().progress_message.c_str()); } - // Finalize an update line by adding a newline - public: - static void finalizeUpdate(LogLevel level) - { - if (!shouldLog(level)) { - return; - } - - // Add newline - fprintf(stderr, "\n"); + // Build the progress message + int filled = (int)(progress * progressBarWidth); + char progressBar[progressBarWidth + 3]; // progressBarWidth + 2 for ends + // + 1 for null terminator + progressBar[0] = '['; + for (int i = 0; i < progressBarWidth; ++i) { + progressBar[i + 1] = (i < filled) ? '=' : '-'; } - - // Finalize an UPDATE line by adding a newline - static void finalizeProgress(LogLevel level) - { - finalizeUpdate(level); - instance().progress_line_active = false; + progressBar[progressBarWidth + 1] = ']'; + progressBar[progressBarWidth + 2] = '\0'; + instance().progress_message = std::string(progressBar) + " " + userMsg; + + update(level, "%s", instance().progress_message.c_str()); + } + + // Finalize an update line by adding a newline +public: + static void finalizeUpdate(LogLevel level) { + if (!shouldLog(level)) { + return; } - private: - // ANSI terminal control sequences - static constexpr const char* CLEAR_TO_EOL = "\033[K"; + // Add newline + fprintf(stderr, "\n"); + } + + // Finalize an UPDATE line by adding a newline + static void finalizeProgress(LogLevel level) { + finalizeUpdate(level); + instance().progress_line_active = false; + } + +private: + // ANSI terminal control sequences + static constexpr const char *CLEAR_TO_EOL = "\033[K"; - static constexpr int PADDING_SIZE = 80; + static constexpr int PADDING_SIZE = 80; - static bool shouldLog(LogLevel level); - static void clearLine(); - static void finalizeProgressIfActive(); - static void reprintProgressIfActive(); + static bool shouldLog(LogLevel level); + static void clearLine(); + static void finalizeProgressIfActive(); + static void reprintProgressIfActive(); }; } // namespace openzl::tools::logger diff --git a/tools/training/clustering/trainers/bottom_up_trainer.cpp b/tools/training/clustering/trainers/bottom_up_trainer.cpp index 7f62237d..424f92fa 100644 --- a/tools/training/clustering/trainers/bottom_up_trainer.cpp +++ b/tools/training/clustering/trainers/bottom_up_trainer.cpp @@ -2,174 +2,144 @@ #include -#include "tools/training/clustering/trainers/bottom_up_trainer.h" #include "tools/logger/Logger.h" #include "tools/training/clustering/clustering_config_builder.h" +#include "tools/training/clustering/trainers/bottom_up_trainer.h" namespace openzl::training { using namespace openzl::tools::logger; ClusteringConfigBuilder BottomUpTrainer::buildTrainedFullSplitConfig( - const CompressionUtils& cUtils, - const ColumnMetadata& metadata, - const std::map, size_t>& - typeToDefaultSuccessorIdxMap) -{ - auto config = ClusteringConfigBuilder::buildFullSplitConfig( - metadata, - typeToDefaultSuccessorIdxMap, - cUtils.getTypeToClusteringCodecIdxsMap()); + const CompressionUtils &cUtils, const ColumnMetadata &metadata, + const std::map, size_t> + &typeToDefaultSuccessorIdxMap) { + auto config = ClusteringConfigBuilder::buildFullSplitConfig( + metadata, typeToDefaultSuccessorIdxMap, + cUtils.getTypeToClusteringCodecIdxsMap()); - // Pick successor and clustering codec for each cluster in full split - std::vector> futures; - futures.reserve(config.clusters().size()); - const auto& clusters = config.clusters(); - for (const auto& cluster : clusters) { - const auto& tags = cluster.memberTags; - auto type = cluster.typeSuccessor.type; - auto width = cluster.typeSuccessor.eltWidth; - auto task = [&cUtils, &metadata, tags, type, width]() { - auto clusterInfo = - cUtils.getBestClusterInfo(tags, type, width, metadata); - return clusterInfo; - }; - futures.emplace_back(this->threadPool_->run(task)); - } - size_t clusterIdx = 0; - for (auto& future : futures) { - auto clusterInfo = future.get(); - config.setClusterSuccessor(clusterIdx, clusterInfo.successorIdx); - config.setClusteringCodec(clusterIdx, clusterInfo.clusteringCodecIdx); - clusterIdx++; - } - return config; + // Pick successor and clustering codec for each cluster in full split + std::vector> futures; + futures.reserve(config.clusters().size()); + const auto &clusters = config.clusters(); + for (const auto &cluster : clusters) { + const auto &tags = cluster.memberTags; + auto type = cluster.typeSuccessor.type; + auto width = cluster.typeSuccessor.eltWidth; + auto task = [&cUtils, &metadata, tags, type, width]() { + auto clusterInfo = cUtils.getBestClusterInfo(tags, type, width, metadata); + return clusterInfo; + }; + futures.emplace_back(this->threadPool_->run(task)); + } + size_t clusterIdx = 0; + for (auto &future : futures) { + auto clusterInfo = future.get(); + config.setClusterSuccessor(clusterIdx, clusterInfo.successorIdx); + config.setClusteringCodec(clusterIdx, clusterInfo.clusteringCodecIdx); + clusterIdx++; + } + return config; } ClusteringConfigBuilder BottomUpTrainer::buildTrainedConfigAddInputToCluster( - const CompressionUtils& cUtils, - const ColumnMetadata& metadata, - const ClusteringConfigBuilder& config, - int tag, - ZL_Type type, - size_t eltWidth, - size_t clusterIdx) -{ - auto candidate = config.buildConfigAddInputToCluster( - tag, type, eltWidth, clusterIdx); - auto clusterInfo = cUtils.getBestClusterInfo( - candidate.clusters()[clusterIdx].memberTags, - type, - eltWidth, - metadata); - candidate.setClusterSuccessor(clusterIdx, clusterInfo.successorIdx); - candidate.setClusteringCodec(clusterIdx, clusterInfo.clusteringCodecIdx); - return candidate; + const CompressionUtils &cUtils, const ColumnMetadata &metadata, + const ClusteringConfigBuilder &config, int tag, ZL_Type type, + size_t eltWidth, size_t clusterIdx) { + auto candidate = + config.buildConfigAddInputToCluster(tag, type, eltWidth, clusterIdx); + auto clusterInfo = cUtils.getBestClusterInfo( + candidate.clusters()[clusterIdx].memberTags, type, eltWidth, metadata); + candidate.setClusterSuccessor(clusterIdx, clusterInfo.successorIdx); + candidate.setClusteringCodec(clusterIdx, clusterInfo.clusteringCodecIdx); + return candidate; } ClusteringConfigBuilder BottomUpTrainer::getTrainedClusteringConfig( - const ZL_Compressor* compressor, - const std::vector& samples, - const std::vector& successors, - const std::vector& clusteringCodecs, - const std::map, size_t>& - typeToDefaultSuccessorIdxMap) -{ - auto start = std::chrono::high_resolution_clock::now(); - auto cUtils = CompressionUtils( - compressor, - samples, - successors, - clusteringCodecs, - this->threadPool_); - auto metadata = cUtils.aggregateInputMetadata(); - auto config = buildTrainedFullSplitConfig( - cUtils, metadata, typeToDefaultSuccessorIdxMap); - Logger::log_c( - INFO, - "Created trained full split config with %zu inputs", - metadata.size()); + const ZL_Compressor *compressor, const std::vector &samples, + const std::vector &successors, + const std::vector &clusteringCodecs, + const std::map, size_t> + &typeToDefaultSuccessorIdxMap) { + auto start = std::chrono::high_resolution_clock::now(); + auto cUtils = CompressionUtils(compressor, samples, successors, + clusteringCodecs, this->threadPool_); + auto metadata = cUtils.aggregateInputMetadata(); + auto config = buildTrainedFullSplitConfig(cUtils, metadata, + typeToDefaultSuccessorIdxMap); + Logger::log_c(INFO, "Created trained full split config with %zu inputs", + metadata.size()); - // Build clusters up iteratively from full split state - size_t nbClusters = 0; - auto bestCost = cUtils.tryCompress(config.build()).get(); - for (const auto& data : metadata) { - if (nbClusters == 0) { - // No cluster has been built yet for the first tag. Track it in the - // set of clusters - nbClusters++; - continue; - } - auto now = std::chrono::high_resolution_clock::now(); - auto duration = - std::chrono::duration_cast(now - start); - if (maxTimeSecs_.has_value() - && (size_t)duration.count() > maxTimeSecs_.value()) { - Logger::log_c( - INFO, - "Stopping training early after %zu s. Exceeded max time of %zu s.", - duration.count(), - maxTimeSecs_.value()); - return config; - } - auto tag = data.tag; - auto typeInfo = std::make_pair(data.type, data.width); - bool hasImprovement = false; - // Try add tag to existing clusters only if it hasn't been visited - std::vector> candidateFutures; - candidateFutures.reserve(nbClusters); - for (size_t i = 0; i < nbClusters; i++) { - if (!config.typeIsCompatibleWithClusterIdx( - data.type, data.width, (int)i)) { - continue; - } - auto task = - [this, &cUtils, &metadata, &config, tag, typeInfo, i]() { - return buildTrainedConfigAddInputToCluster( - cUtils, - metadata, - config, - tag, - typeInfo.first, - typeInfo.second, - i); - }; - candidateFutures.emplace_back(this->threadPool_->run(task)); - } - std::vector> costFutures; - std::vector candidates; - costFutures.reserve(candidateFutures.size()); - candidates.reserve(candidateFutures.size()); - for (auto& candidateFuture : candidateFutures) { - auto res = candidateFuture.get(); - costFutures.emplace_back(cUtils.tryCompress(res.build())); - candidates.emplace_back(std::move(res)); - } - for (size_t i = 0; i < candidates.size(); i++) { - assert(i < costFutures.size()); - auto cost = costFutures[i].get(); - if (cost < bestCost) { - bestCost = cost; - config = std::move(candidates[i]); - hasImprovement = true; - } - } + // Build clusters up iteratively from full split state + size_t nbClusters = 0; + auto bestCost = cUtils.tryCompress(config.build()).get(); + for (const auto &data : metadata) { + if (nbClusters == 0) { + // No cluster has been built yet for the first tag. Track it in the + // set of clusters + nbClusters++; + continue; + } + auto now = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(now - start); + if (maxTimeSecs_.has_value() && + (size_t)duration.count() > maxTimeSecs_.value()) { + Logger::log_c( + INFO, + "Stopping training early after %zu s. Exceeded max time of %zu s.", + duration.count(), maxTimeSecs_.value()); + return config; + } + auto tag = data.tag; + auto typeInfo = std::make_pair(data.type, data.width); + bool hasImprovement = false; + // Try add tag to existing clusters only if it hasn't been visited + std::vector> candidateFutures; + candidateFutures.reserve(nbClusters); + for (size_t i = 0; i < nbClusters; i++) { + if (!config.typeIsCompatibleWithClusterIdx(data.type, data.width, + (int)i)) { + continue; + } + auto task = [this, &cUtils, &metadata, &config, tag, typeInfo, i]() { + return buildTrainedConfigAddInputToCluster( + cUtils, metadata, config, tag, typeInfo.first, typeInfo.second, i); + }; + candidateFutures.emplace_back(this->threadPool_->run(task)); + } + std::vector> costFutures; + std::vector candidates; + costFutures.reserve(candidateFutures.size()); + candidates.reserve(candidateFutures.size()); + for (auto &candidateFuture : candidateFutures) { + auto res = candidateFuture.get(); + costFutures.emplace_back(cUtils.tryCompress(res.build())); + candidates.emplace_back(std::move(res)); + } + for (size_t i = 0; i < candidates.size(); i++) { + assert(i < costFutures.size()); + auto cost = costFutures[i].get(); + if (cost < bestCost) { + bestCost = cost; + config = std::move(candidates[i]); + hasImprovement = true; + } + } - if (hasImprovement) { - // A candidate has been found that adds to a tracked cluster. - // Therefore we do not increment the number of clusters. - Logger::log_c(VERBOSE1, "New cost: %zu", bestCost.compressedSize); - } else { - // Increment the number of clusters to track the new tag as its own - // cluster - nbClusters++; - Logger::log_c(VERBOSE1, "No improvement found using tag: %zu", tag); - } + if (hasImprovement) { + // A candidate has been found that adds to a tracked cluster. + // Therefore we do not increment the number of clusters. + Logger::log_c(VERBOSE1, "New cost: %zu", bestCost.compressedSize); + } else { + // Increment the number of clusters to track the new tag as its own + // cluster + nbClusters++; + Logger::log_c(VERBOSE1, "No improvement found using tag: %zu", tag); } - Logger::log_c( - VERBOSE1, - "Final config found with cost: ", - bestCost.compressedSize); - return config; + } + Logger::log_c(VERBOSE1, + "Final config found with cost: ", bestCost.compressedSize); + return config; } } // namespace openzl::training