From e88ad81c5eeb10166b93a35a025324603a5bc6d9 Mon Sep 17 00:00:00 2001 From: Askold Ilvento Date: Fri, 30 Jan 2026 23:35:07 +0300 Subject: [PATCH] allow adjoint parameters --- include/argparse/argparse.hpp | 32 ++++++++++++++----- tests/tests.cpp | 60 ++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 9 deletions(-) mode change 100755 => 100644 include/argparse/argparse.hpp mode change 100755 => 100644 tests/tests.cpp diff --git a/include/argparse/argparse.hpp b/include/argparse/argparse.hpp old mode 100755 new mode 100644 index 8d7dfbb..bcae6b1 --- a/include/argparse/argparse.hpp +++ b/include/argparse/argparse.hpp @@ -42,6 +42,7 @@ #include // for vector #include // for std::wstring_convert #include // for std::wstring_convert +#include // for short flag names look-up table // for enum_entries #if __has_include() @@ -138,7 +139,7 @@ namespace argparse { template<> inline short get(const std::string &v) { return std::stoi(v); } template<> inline long get(const std::string &v) { return std::stol(v); } template<> inline long long get(const std::string &v) { return std::stol(v); } - template<> inline bool get(const std::string &v) { return to_lower(v) == "true" || v == "1"; } + template<> inline bool get(const std::string &v) { return to_lower(v) == "true" || v == "1" || v == "on"; } template<> inline float get(const std::string &v) { return std::stof(v); } template<> inline double get(const std::string &v) { return std::stod(v); } template<> inline unsigned char get(const std::string &v) { return get(v); } @@ -350,6 +351,7 @@ namespace argparse { std::map> kwarg_entries; std::vector> arg_entries; std::map> subcommand_entries; + std::bitset<256> short_explicit_names; bool has_options() { return std::find_if(all_entries.begin(), all_entries.end(), [](auto e) { return e->type != Entry::ARG; }) != all_entries.end(); }; @@ -396,6 +398,9 @@ namespace argparse { all_entries.emplace_back(entry); for (const std::string &k : entry->keys_) { kwarg_entries[k] = entry; + if (k.size() == 1 && !implicit_value) { + short_explicit_names.set(k[0]); + } } return *entry; } @@ -493,6 +498,15 @@ namespace argparse { auto is_value = [&](const size_t &i) -> bool { return params.size() > i && (params[i][0] != '-' || (params[i].size() > 1 && std::isdigit(params[i][1]))); // check for number to not accidentally mark negative numbers as non-parameter }; + + auto parse_multi_argument = [&](size_t &i, Entry& entry, std::string value) { + if (entry._is_multi_argument) { + while (is_value(i + 1)) + value += "," + params[++i]; + } + entry._convert(value); + }; + auto parse_param = [&](size_t &i, const std::string &key, const bool is_short, const std::optional &equal_value=std::nullopt) { auto itt = kwarg_entries.find(key); if (itt != kwarg_entries.end()) { @@ -503,12 +517,7 @@ namespace argparse { entry->_convert(*entry->implicit_value_); } else if (!is_short) { // short values are not allowed to look ahead for the next parameter if (is_value(i + 1)) { - std::string value = params[++i]; - if (entry->_is_multi_argument) { - while (is_value(i + 1)) - value += "," + params[++i]; - } - entry->_convert(value); + parse_multi_argument(i, *entry, params[++i]); } else if (entry->_is_multi_argument) { entry->_convert(""); // for multiargument parameters, return an empty vector when not passing any more values } else { @@ -524,6 +533,7 @@ namespace argparse { cerr << "unrecognised commandline argument : " << key << endl; } }; + auto add_param = [&](size_t &i, const size_t &start) { size_t eq_idx = params[i].find('='); // check if value was passed using the '=' sign if (eq_idx != std::string::npos) { // key/value from = notation @@ -545,10 +555,15 @@ namespace argparse { const size_t j_end = std::min(params[i].size(), params[i].find('=')) - 1; for (size_t j = 1; j < j_end; j++) { // add possible other flags const std::string key = std::string(1, params[i][j]); + if (short_explicit_names[params[i][j]]) { + parse_multi_argument(i, *kwarg_entries[key], params[i].substr(j + 1)); + goto skip; + } parse_param(i, key, true); } add_param(i, j_end); - } + skip:; + } } else { arguments_flat.emplace_back(params[i]); } @@ -560,6 +575,7 @@ namespace argparse { if (arg_i < arguments_flat.size()) arg_entries[arg_i]->_convert(arguments_flat[arg_i]); } + size_t arg_j = 1; for (size_t j_end = arg_entries.size() - arg_i; arg_j <= j_end; arg_j++) { // iterate from back to front, to ensure non-multi-arguments in the front and back are given preference size_t flat_idx = arguments_flat.size() - arg_j; diff --git a/tests/tests.cpp b/tests/tests.cpp old mode 100755 new mode 100644 index a3c2071..e6c7b8d --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -186,32 +186,42 @@ void TEST_THROW() { { std::string command = "argparse_test"; const auto &[argc, argv] = get_argc_argv(command); + bool is_thrown{}; try { auto args = argparse::parse(argc, argv, true); } catch (const std::runtime_error &e) { assert(std::string(e.what()) == "Argument missing: arg_0 (Source path)"); + is_thrown = true; } + assert(is_thrown && "Must throw"); } { std::string command = "argparse_test source_path -k=notanobumber"; const auto &[argc, argv] = get_argc_argv(command); + bool is_thrown{}; try { auto args = argparse::parse(argc, argv, true); } catch (const std::runtime_error &e) { assert(std::string(e.what()) == "Invalid argument, could not convert \"notanobumber\" for -k ()"); + is_thrown = true; } + assert(is_thrown && "Must throw"); } { std::string command = "argparse_test source_path source_path"; const auto &[argc, argv] = get_argc_argv(command); + bool is_thrown{}; try { auto args = argparse::parse(argc, argv, true); } catch (const std::runtime_error &e) { assert(std::string(e.what()) == "Argument missing: -a,--alpha (required alpha value)"); + is_thrown = true; } + assert(is_thrown && "Must throw"); } + } void TEST_SUBCOMMANDS() { @@ -320,6 +330,53 @@ void TEST_OPTIONAL_POINTER() { } } +void TEST_ADJOINT() { + struct Args : public argparse::Args { + int &number1 = kwarg("o,number1", "A optional number").set_default(10); + int &number2 = kwarg("n,number2", "A mandatory number"); + std::vector &numbers = kwarg("m,numbers", "Multiple numbers").multi_argument().set_default(std::vector()); + bool &flag1 = flag("f,flag", "A flag"); + bool &flag2 = flag("flag2", "A flag"); + bool &flag3 = flag("g", "A flag"); + }; + + bool is_thrown{}; + try { + std::string command{"argparse_test -n10 -flag2"}; + const auto &[argc, argv] = get_argc_argv(command); + auto args = argparse::parse(argc, argv, true); + } catch (const std::runtime_error &e) { + std::cout << std::string(e.what()) << '\n'; + assert(std::string(e.what()) == "unrecognised commandline argument : l"); + is_thrown = true; + } + assert(is_thrown && "Must throw"); + + { + Args args = test_args("argparse_test -n10 --flag2"); + assert(args.number2 == 10); + assert(args.flag2); + } + + { + Args args = test_args("argparse_test -fn10 -go5"); + assert(args.number2 == 10); + assert(args.number1 == 5); + assert(args.flag1); + assert(!args.flag2); + assert(args.flag3); + } + + { + Args args = test_args("argparse_test -fgm1 2 3 -n=12"); + assert(args.flag1); + assert(!args.flag2); + assert(args.flag3); + assert(args.number2 == 12); + assert(args.numbers == std::vector({1, 2, 3})); + } +} + int main(int argc, char* argv[]) { TEST_ALL(); TEST_MULTI(); @@ -331,11 +388,12 @@ int main(int argc, char* argv[]) { std::cout << "Magic Enum not installed in this system, therefore native enum support disabled" << std::endl; #endif - TEST_SUBCOMMANDS(); + TEST_SUBCOMMANDS(); TEST_SHORT_GROUP(); TEST_EQUALS(); TEST_EMPTY_MULTI(); TEST_OPTIONAL_POINTER(); + TEST_ADJOINT(); std::cout << "finished all tests" << std::endl; return 0;