diff --git a/src/cli.cpp b/src/cli.cpp index 750ab30..027a4db 100644 --- a/src/cli.cpp +++ b/src/cli.cpp @@ -1,18 +1,13 @@ +#include "cli.h" + #include -#include "cli.h" #include "key.h" -void cli::help() { - std::cout << "Welcome to RSA-Encryptor" << std::endl; -} +void cli::help() { std::cout << "Welcome to RSA-Encryptor" << std::endl; } -void cli::keyManager::create() { - key::createRSAKey(); -} +void cli::keyManager::create() { key::createRSAKey(); } -void cli::keyManager::list() { -} +void cli::keyManager::list() {} -void cli::keyManager::print(const std::string& name, bool publicKey, bool privateKey) { -} +void cli::keyManager::print(const std::string& name, bool publicKey, bool privateKey) {} diff --git a/src/cli.h b/src/cli.h index 9b1ed5b..f56e569 100644 --- a/src/cli.h +++ b/src/cli.h @@ -1,14 +1,16 @@ #ifndef CLI_H #define CLI_H +#include + namespace cli { - void help(); - - namespace keyManager { - void create(); - void list(); - void print(const std::string& name, bool publicKey, bool privateKey); - } -} +void help(); + +namespace keyManager { +void create(); +void list(); +void print(const std::string& name, bool publicKey, bool privateKey); +} // namespace keyManager +} // namespace cli #endif \ No newline at end of file diff --git a/src/encryption.cpp b/src/encryption.cpp index 6016c19..6ac6e58 100644 --- a/src/encryption.cpp +++ b/src/encryption.cpp @@ -1,7 +1,7 @@ -#include - #include "encryption.h" +#include + std::int8_t encrypt(std::uint8_t data, unsigned long int e, unsigned long N) { return std::pow(data, e); } diff --git a/src/encryption.h b/src/encryption.h index d4b13b9..111cf26 100644 --- a/src/encryption.h +++ b/src/encryption.h @@ -1,13 +1,12 @@ #ifndef ENCRYPTION_H #define ENCRYPTION_H -#include #include +#include -class encryption -{ -private: -public: +class encryption { + private: + public: std::int8_t encrypt(std::uint8_t data, unsigned long int e, unsigned long N); }; diff --git a/src/key.cpp b/src/key.cpp index 44b4b23..0c0e6ec 100644 --- a/src/key.cpp +++ b/src/key.cpp @@ -1,9 +1,9 @@ +#include "key.h" + #include #include #include -#include "key.h" - std::filesystem::path key::keysPath() { // Navigate to the root directory of the project std::filesystem::path programRootPath = std::filesystem::current_path().parent_path(); @@ -13,26 +13,36 @@ std::filesystem::path key::keysPath() { int key::keyExists(std::string name) { // TODO fix function not case sensitive - /* 0: key doesn't exist - * 1: only publicKey exists - * 2: only privateKey exists - * 3: both keys exists */ + // 0: key doesn't exist + // 1: only publicKey exists + // 2: only privateKey exists + // 3: both keys exists int status = NONE; std::filesystem::path keysFolder = keysPath(); // go over every file in the KEY_FOLDER - for (const auto& entry : std::filesystem::directory_iterator(keysFolder)) { + for (const auto &entry : std::filesystem::directory_iterator(keysFolder)) { // check if the name matches if (entry.path().stem().string() == name) { // check if it's a public key if (entry.path().extension() == ".pub") { // break the loop if both files are found - if (status == NONE) status = PUBLIC; else {status = BOTH; break;}; + if (status == NONE) + status = PUBLIC; + else { + status = BOTH; + break; + }; } // check if it's a private key else if (entry.path().has_extension() == false) { - if (status == NONE) status = PRIVATE; else {status = BOTH; break;}; + if (status == NONE) + status = PRIVATE; + else { + status = BOTH; + break; + }; } } } @@ -40,12 +50,12 @@ int key::keyExists(std::string name) { return status; } -int key::writeKey(const std::string& name, std::vector *data, const bool isPublic) { +int key::writeKey(const std::string &name, std::vector *data, const bool isPublic) { if (data == nullptr) return -1; std::filesystem::path keysFolder = keysPath(); - std::filesystem::path keyFile = keysFolder / (name + (isPublic? ".pub":"")); + std::filesystem::path keyFile = keysFolder / (name + (isPublic ? ".pub" : "")); // load file in memory std::ofstream outFile(keyFile); @@ -53,11 +63,11 @@ int key::writeKey(const std::string& name, std::vector *data, const boo std::string base64Data = base64Encode(*data); if (outFile.is_open()) { - outFile << "-----BEGIN RSA " << (isPublic? "PUBLIC":"PRIVATE") << " KEY-----\n"; + outFile << "-----BEGIN RSA " << (isPublic ? "PUBLIC" : "PRIVATE") << " KEY-----\n"; for (int i = 0; i < base64Data.size(); i += 64) { outFile << base64Data.substr(i, 64) << "\n"; } - outFile << "-----END RSA " << (isPublic? "PUBLIC":"PRIVATE") << " KEY-----\n"; + outFile << "-----END RSA " << (isPublic ? "PUBLIC" : "PRIVATE") << " KEY-----\n"; outFile.close(); } else { std::cerr << "Could not write to file: " << keyFile << std::endl; @@ -70,7 +80,7 @@ int key::writeKey(const std::string& name, std::vector *data, const boo int key::readKey(const std::string &name, std::vector *data, bool isPublic) { std::filesystem::path keysFolder = keysPath(); - std::filesystem::path keyFile = keysFolder / (name + (isPublic? ".pub":"")); + std::filesystem::path keyFile = keysFolder / (name + (isPublic ? ".pub" : "")); // load file in memory std::ifstream inFile(keyFile); @@ -83,8 +93,7 @@ int key::readKey(const std::string &name, std::vector *data, bool isPub std::string currentLine; std::string keyString; - while (getline (inFile, currentLine)) { - + while (getline(inFile, currentLine)) { if (currentLine.empty()) continue; if (currentLine.at(0) == '-') continue; @@ -111,7 +120,6 @@ std::string key::base64Encode(const std::vector &data) { int index = 0; while (index < data.size()) { - uint32_t dataSegment = 0; // take 3 numbers from the vector @@ -143,7 +151,7 @@ std::string key::base64Encode(const std::vector &data) { std::vector key::base64Decode(std::string data) { // TODO the decoding might not work if the data size is not dividable by 4 if (data.size() % 4 != 0) std::cerr << "Decoding data size not dividable by 4" << std::endl; - + std::vector result; while (!data.empty()) { @@ -162,7 +170,8 @@ std::vector key::base64Decode(std::string data) { // check for valid char code if (number > 0b111111) { - std::cerr << "trying to Decode not valid char: " << letterSegment.at(i) << std::endl; + std::cerr << "trying to Decode not valid char: " << letterSegment.at(i) + << std::endl; continue; } dataSegment += number << i * 6; @@ -187,24 +196,14 @@ uint8_t key::getBase64Index(char letter) { } int key::createKey(std::vector *keyPublic, std::vector *keyPrivate) { - *keyPublic = { - 23, 87, 45, 190, 12, 78, 34, 210, 56, 89, - 123, 67, 90, 150, 32, 76, 54, 200, 11, 99, - 101, 145, 67, 189, 43, 88, 29, 176, 58, 92, - 111, 134, 78, 201, 15, 84, 39, 220, 66, 97, - 105, 142, 71, 185, 49, 81, 27, 170, 53, 95, - 41 - }; - *keyPrivate = { - 34, 78, 123, 56, 89, 210, 45, 190, 12, 87, - 67, 150, 32, 76, 54, 200, 11, 99, 101, 145, - 67, 189, 43, 88, 29, 176, 58, 92, 111, 134, - 78, 201, 15, 84, 39, 220, 66, 97, 105, 142, - 71, 185, 49, 81, 27, 170, 53, 95, 102, 147, - 68, 191, 44, 85, 30, 177, 59, 93, 112, 135, - 79, 202, 16, 85, 40, 221, 67, 98, 23, 87, - 91, 65 - }; + *keyPublic = {23, 87, 45, 190, 12, 78, 34, 210, 56, 89, 123, 67, 90, 150, 32, 76, 54, + 200, 11, 99, 101, 145, 67, 189, 43, 88, 29, 176, 58, 92, 111, 134, 78, 201, + 15, 84, 39, 220, 66, 97, 105, 142, 71, 185, 49, 81, 27, 170, 53, 95, 41}; + *keyPrivate = {34, 78, 123, 56, 89, 210, 45, 190, 12, 87, 67, 150, 32, 76, 54, + 200, 11, 99, 101, 145, 67, 189, 43, 88, 29, 176, 58, 92, 111, 134, + 78, 201, 15, 84, 39, 220, 66, 97, 105, 142, 71, 185, 49, 81, 27, + 170, 53, 95, 102, 147, 68, 191, 44, 85, 30, 177, 59, 93, 112, 135, + 79, 202, 16, 85, 40, 221, 67, 98, 23, 87, 91, 65}; return 1; } @@ -228,8 +227,8 @@ void key::createRSAKey() { return; } - auto* keyPublic = new std::vector(); - auto* keyPrivate = new std::vector(); + auto *keyPublic = new std::vector(); + auto *keyPrivate = new std::vector(); createKey(keyPublic, keyPrivate); @@ -241,14 +240,14 @@ void key::createRSAKey() { } } -std::vector * key::getPrivateKey(std::string &name) { - std::vector* data = new std::vector; +std::vector *key::getPrivateKey(std::string &name) { + std::vector *data = new std::vector; readKey(name, data, false); return data; } -std::vector * key::getPublicKey(std::string &name) { - std::vector* data = new std::vector; +std::vector *key::getPublicKey(std::string &name) { + std::vector *data = new std::vector; readKey(name, data, true); return data; } diff --git a/src/key.h b/src/key.h index abdd808..005a59d 100644 --- a/src/key.h +++ b/src/key.h @@ -1,24 +1,19 @@ #ifndef KEY_H #define KEY_H -#include #include +#include #include #include #define KEY_FOLDER "rsa-keys" -enum { - NONE, - PUBLIC, - PRIVATE, - BOTH -}; +enum { NONE, PUBLIC, PRIVATE, BOTH }; const char base64Chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; class key { -private: + private: static std::filesystem::path keysPath(); static int keyExists(std::string name); static void generatePublicFromPrivate(); @@ -33,11 +28,11 @@ class key { static int createKey(std::vector *keyPublic, std::vector *keyPrivate); -public: + public: static void createRSAKey(); - static std::vector* getPrivateKey(std::string &name); - static std::vector* getPublicKey(std::string &name); + static std::vector *getPrivateKey(std::string &name); + static std::vector *getPublicKey(std::string &name); }; #endif \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 15e2ebd..dd0d359 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,58 +1,56 @@ -#include -#include -#include #include +#include #include +#include +#include -#include "utility.h" #include "cli.h" #include "key.h" +#include "utility.h" #include "vec/operations.h" int main(int argc, char* argv[]) { - - struct { - bool privateKey = false; - bool publicKey = false; - } arguments; - - // Check if there are additional arguments - if (argc > 1) { - - // Check for additional arguments - for (int i = 1; argc > i; i++) { - // Check for a public key - if (!strcmp(argv[i], "-p") || !strcmp(argv[i], "--public")) { - arguments.publicKey = true; - } - - // Check for a private key - if (!strcmp(argv[i], "-P") || !strcmp(argv[i], "--private")) { - arguments.privateKey = true; - } - } - - std::string feature = argv[1]; - - if (feature == "key" && argv[2] != nullptr) { - std::string subFeature = argv[2]; - - if (subFeature == "create") { - cli::keyManager::create(); - } else if (subFeature == "list") { - cli::keyManager::list(); - } else if (subFeature == "print" && argv[3] != nullptr) { - cli::keyManager::print(argv[3], arguments.publicKey, arguments.privateKey); - } - - } else if (feature == "help") { - cli::help(); - } else { - cli::help(); - } - } else { - cli::help(); - } - - return 0; + struct { + bool privateKey = false; + bool publicKey = false; + } arguments; + + // Check if there are additional arguments + if (argc > 1) { + // Check for additional arguments + for (int i = 1; argc > i; i++) { + // Check for a public key + if (!strcmp(argv[i], "-p") || !strcmp(argv[i], "--public")) { + arguments.publicKey = true; + } + + // Check for a private key + if (!strcmp(argv[i], "-P") || !strcmp(argv[i], "--private")) { + arguments.privateKey = true; + } + } + + std::string feature = argv[1]; + + if (feature == "key" && argv[2] != nullptr) { + std::string subFeature = argv[2]; + + if (subFeature == "create") { + cli::keyManager::create(); + } else if (subFeature == "list") { + cli::keyManager::list(); + } else if (subFeature == "print" && argv[3] != nullptr) { + cli::keyManager::print(argv[3], arguments.publicKey, arguments.privateKey); + } + + } else if (feature == "help") { + cli::help(); + } else { + cli::help(); + } + } else { + cli::help(); + } + + return 0; } diff --git a/src/utility.cpp b/src/utility.cpp index 688a41b..31109dd 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -7,18 +7,13 @@ int utility::gcd(int a, int b) { return gcd(b % a, a); } -int utility::phi(int a, int b) { - return ((a - 1) * (b - 1)); -} +int utility::phi(int a, int b) { return ((a - 1) * (b - 1)); } -bool utility::checkForPrime(int number) -{ +bool utility::checkForPrime(int number) { int counter = 0; if (number <= 1) { return false; - } - else { - + } else { // Check how many numbers divide n in // Range 2 to sqrt(n) for (int i = 2; i * i <= number; i++) { @@ -30,8 +25,7 @@ bool utility::checkForPrime(int number) // If counter is greater than 0 then n is prime if (counter > 0) { return false; - } - else { + } else { return true; } } diff --git a/src/utility.h b/src/utility.h index 29dd202..5cae06c 100644 --- a/src/utility.h +++ b/src/utility.h @@ -3,11 +3,9 @@ #include -class utility -{ -private: - -public: +class utility { + private: + public: int gcd(int a, int b); int phi(int a, int b); bool checkForPrime(int number); diff --git a/src/vec/operations.h b/src/vec/operations.h index c76a995..60308fd 100644 --- a/src/vec/operations.h +++ b/src/vec/operations.h @@ -1,362 +1,333 @@ #ifndef VEC_OPERATIONS_H #define VEC_OPERATIONS_H +#include #include #include -#include namespace operations { - [[nodiscard]] std::uint64_t getStartBitIndex (const std::vector &a) { - std::uint64_t index = 0; - std::uint64_t bitsSince1 = 0; - - for (std::uint8_t number: a) { - // Skip 8 bits if all zero - if (number == 0) { - bitsSince1 += 8; - continue; - } - // Go over every bit in number - for (std::uint64_t i = 0; i < sizeof(std::uint8_t) * 8; i++) { - // Select the current bit using a bitmask - if ((number & 0b1 << i) != 0) { - // Increment the index by the bits since the last increment - index += bitsSince1 + 1; - bitsSince1 = 0; - } else { - bitsSince1++; - } +[[nodiscard]] std::uint64_t getStartBitIndex(const std::vector &a) { + std::uint64_t index = 0; + std::uint64_t bitsSince1 = 0; + + for (std::uint8_t number : a) { + // Skip 8 bits if all zero + if (number == 0) { + bitsSince1 += 8; + continue; + } + // Go over every bit in number + for (std::uint64_t i = 0; i < sizeof(std::uint8_t) * 8; i++) { + // Select the current bit using a bitmask + if ((number & 0b1 << i) != 0) { + // Increment the index by the bits since the last increment + index += bitsSince1 + 1; + bitsSince1 = 0; + } else { + bitsSince1++; } } - - // Returns position of the first bit needed for the number - return index - 1; } - /* This function picks one bit from pickNumber and places it at the least significant - * position of number. */ - [[nodiscard]] std::vector addBitFromNumber (const std::vector &number, const std::vector &pickNumber, std::uint32_t index) { - std::vector result; - - // Check if the index is valid - if (index > getStartBitIndex(pickNumber)) { - return {0}; - } + // Returns position of the first bit needed for the number + return index - 1; +} - // Check if the bit at the given index is set - bool mostSignificantBit = pickNumber[index / 8] & 0b1 << index % 8; +// This function picks one bit from pickNumber and places it at the least significant position of +// number. +[[nodiscard]] std::vector addBitFromNumber( + const std::vector &number, const std::vector &pickNumber, + std::uint32_t index) { + std::vector result; - for (std::uint8_t currentByte : number) { - result.push_back((currentByte << 1) + mostSignificantBit); + // Check if the index is valid + if (index > getStartBitIndex(pickNumber)) { + return {0}; + } - // Check if the most significant bit of the current byte is set - mostSignificantBit = currentByte & 0b1000000 == 0b10000000; - } + // Check if the bit at the given index is set + bool mostSignificantBit = pickNumber[index / 8] & 0b1 << index % 8; - /* In case the length of the vector and the actual number length matches, the number - * increases by one vector element with value 1 */ - if (mostSignificantBit) result.push_back(0b1); + for (std::uint8_t currentByte : number) { + result.push_back((currentByte << 1) + mostSignificantBit); - return result; + // Check if the most significant bit of the current byte is set + mostSignificantBit = currentByte & 0b1000000 == 0b10000000; } - [[nodiscard]] bool isZero (const std::vector &a) { - for (std::uint8_t number: a) { - if (number != 0) return false; - } + // In case the length of the vector and the actual number length matches, the numberincreases by + // one vector element with value 1 + if (mostSignificantBit) result.push_back(0b1); - return true; + return result; +} + +[[nodiscard]] bool isZero(const std::vector &a) { + for (std::uint8_t number : a) { + if (number != 0) return false; } - [[nodiscard]] bool isEqual ( - const std::vector &a, - const std::vector &b) noexcept - { - const std::uint32_t aSize = a.size(); - const std::uint32_t bSize = b.size(); - const std::uint64_t iterations = aSize > bSize ? aSize : bSize; - - // In case both numbers are the same length these variables could point ot the same vector - const std::uint32_t shortest = aSize < bSize ? aSize : bSize; - const std::vector longest = aSize > bSize ? a : b; - - for (std::uint64_t i = 0; i < iterations; i++) { - if (i >= shortest) { - if (longest[i] != 0) return false; - } else { - if (a[i] != b[i]) return false; - } - } + return true; +} - return true; +[[nodiscard]] bool isEqual(const std::vector &a, + const std::vector &b) noexcept { + const std::uint32_t aSize = a.size(); + const std::uint32_t bSize = b.size(); + const std::uint64_t iterations = aSize > bSize ? aSize : bSize; + + // In case both numbers are the same length these variables could point ot the same vector + const std::uint32_t shortest = aSize < bSize ? aSize : bSize; + const std::vector longest = aSize > bSize ? a : b; + + for (std::uint64_t i = 0; i < iterations; i++) { + if (i >= shortest) { + if (longest[i] != 0) return false; + } else { + if (a[i] != b[i]) return false; + } } - [[nodiscard]] bool isBigger ( - const std::vector &a, - const std::vector &b) noexcept - { - const std::uint32_t aSize = a.size(); - const std::uint32_t bSize = b.size(); - const std::uint64_t iterations = aSize > bSize ? aSize : bSize; + return true; +} - // We loop reverse throw all the vectors to catch leading zeros - for (std::int64_t i = iterations - 1; i >= 0; i--) { +[[nodiscard]] bool isBigger(const std::vector &a, + const std::vector &b) noexcept { + const std::uint32_t aSize = a.size(); + const std::uint32_t bSize = b.size(); + const std::uint64_t iterations = aSize > bSize ? aSize : bSize; - const std::uint8_t aValue = (i < aSize) ? a[i] : 0; - const std::uint8_t bValue = (i < bSize) ? b[i] : 0; + // We loop reverse throw all the vectors to catch leading zeros + for (std::int64_t i = iterations - 1; i >= 0; i--) { + const std::uint8_t aValue = (i < aSize) ? a[i] : 0; + const std::uint8_t bValue = (i < bSize) ? b[i] : 0; - if (aValue > bValue) - { - return true; - } + if (aValue > bValue) { + return true; } - - return false; } - [[nodiscard]] std::vector add( - const std::vector &a, - const std::vector &b) noexcept - { - // Time complexity O(iterations) - // Initialize the result vector to store the sum - std::vector result; - - // Get the max iterations based on the largest vector - const int iterations = std::max(a.size(), b.size()); + return false; +} - // Carry to handle overflow between bytes - std::uint16_t carry = 0; +[[nodiscard]] std::vector add(const std::vector &a, + const std::vector &b) noexcept { + // Time complexity O(iterations) + // Initialize the result vector to store the sum + std::vector result; - for (int i = 0; i < iterations; i++) - { - // Set up a holder for the sum - std::uint16_t sum = carry; + // Get the max iterations based on the largest vector + const int iterations = std::max(a.size(), b.size()); - // Only add to sum if we actually have a value in a - if (i < a.size()) - { - sum += a[i]; - } + // Carry to handle overflow between bytes + std::uint16_t carry = 0; - // Only add to sum if we actually have a value in b - if (i < b.size()) - { - sum += b[i]; - } + for (int i = 0; i < iterations; i++) { + // Set up a holder for the sum + std::uint16_t sum = carry; - // Calculate the carry which is the overflow beyond 255 - carry = sum >> 8; - - // Clear out everything except the lsb - // This is done by masking the sum with 255 - // For example - // 0000000101111111 (=383) - // 0000000011111111 (=255) - // ---------------- - // 0000000100000000 - sum &= 0xFF; - - // Append the least significant byte of the sum to the result - result.push_back(static_cast(sum)); + // Only add to sum if we actually have a value in a + if (i < a.size()) { + sum += a[i]; } - // If we got a carry we append this as well - if (carry) - { - result.push_back(static_cast(carry)); + // Only add to sum if we actually have a value in b + if (i < b.size()) { + sum += b[i]; } - return result; + // Calculate the carry which is the overflow beyond 255 + carry = sum >> 8; + + // Clear out everything except the lsb + sum &= 0xFF; + + // Append the least significant byte of the sum to the result + result.push_back(static_cast(sum)); } - // The return value can only be positive, if it would be negative, 0 is returned - [[nodiscard]] std::vector sub( - const std::vector &a, - const std::vector &b) noexcept - { - // Prevent from ending the subtraction before going over the hole subtractor and stop if the result can only be negative - if (getStartBitIndex(b) > getStartBitIndex(a)) return {0}; - - std::vector result; - - // Handle an underflow when subtracting - bool borrow = false; - - for (int i = 0; i < a.size(); i++) - { - std::int32_t subtract; - // Check for the end of the subtractor - if (i >= b.size()) { - subtract = borrow; - } else { - subtract = borrow + b[i]; - } - borrow = false; + // If we got a carry we append this as well + if (carry) { + result.push_back(static_cast(carry)); + } - if (a[i] >= subtract) { - // If the current number is as least as big as subtract - std::uint8_t number = a[i] - subtract; - result.push_back(number); - } else { - // Borrow from the next number - subtract -= 256; - borrow = true; + return result; +} - // Here subtract can only be 0 or negative - std::uint8_t number = a[i] - subtract; - result.push_back(number); - } +// The return value can only be positive, if it would be negative, 0 is returned +[[nodiscard]] std::vector sub(const std::vector &a, + const std::vector &b) noexcept { + // Prevent from ending the subtraction before going over the hole subtractor and stop if the + // result can only be negative + if (getStartBitIndex(b) > getStartBitIndex(a)) return {0}; + + std::vector result; + + // Handle an underflow when subtracting + bool borrow = false; + + for (int i = 0; i < a.size(); i++) { + std::int32_t subtract; + // Check for the end of the subtractor + if (i >= b.size()) { + subtract = borrow; + } else { + subtract = borrow + b[i]; + } + borrow = false; + + if (a[i] >= subtract) { + // If the current number is as least as big as subtract + std::uint8_t number = a[i] - subtract; + result.push_back(number); + } else { + // Borrow from the next number + subtract -= 256; + borrow = true; + + // Here subtract can only be 0 or negative + std::uint8_t number = a[i] - subtract; + result.push_back(number); } - - return result; } - [[nodiscard]] std::vector convertToVector(std::uint64_t number) noexcept { - - std::vector result; + return result; +} - // Loop until the whole number is zero - while (number) - { - // We only care for the lsb - result.push_back(static_cast(number & 0xFF)); +[[nodiscard]] std::vector convertToVector(std::uint64_t number) noexcept { + std::vector result; - // Shift the number by 8 to the right - number >>= 8; - } + // Loop until the whole number is zero + while (number) { + // We only care for the lsb + result.push_back(static_cast(number & 0xFF)); - return result; + // Shift the number by 8 to the right + number >>= 8; } - [[nodiscard]] std::vector mul( - const std::vector &a, - const std::vector &b) noexcept - { - // Time complexity O(aSize * bSize) - const std::uint64_t aSize = a.size(); - const std::uint64_t bSize = b.size(); + return result; +} - // Initialize the result vector with zeros, with the size of aSize + bSize - std::vector result(aSize + bSize, 0); +[[nodiscard]] std::vector mul(const std::vector &a, + const std::vector &b) noexcept { + // Time complexity O(aSize * bSize) + const std::uint64_t aSize = a.size(); + const std::uint64_t bSize = b.size(); - for (std::uint64_t i = 0; i < aSize; i++) - { - // Set up the carry value for each iteration - std::uint16_t carry = 0; + // Initialize the result vector with zeros, with the size of aSize + bSize + std::vector result(aSize + bSize, 0); - for (uint64_t x = 0; x < bSize; x++) - { - // Calculate the product by adding up the previous result, the carry, and the new product - std::uint16_t product = result[i + x] + carry + (a[i] * b[x]); + for (std::uint64_t i = 0; i < aSize; i++) { + // Set up the carry value for each iteration + std::uint16_t carry = 0; - // Calculate the carry which is the overflow beyond 255 - carry = product >> 8; + for (uint64_t x = 0; x < bSize; x++) { + // Calculate the product by adding up the previous result, the carry, and the new + // product + std::uint16_t product = result[i + x] + carry + (a[i] * b[x]); - // Store the least significant byte (lsb) of the product in the result - result[i + x] = static_cast(product & 0xFF); - } + // Calculate the carry which is the overflow beyond 255 + carry = product >> 8; - // If there is a carry, append it to the result - // But with the offset of bSize because of the previous inner loop - result[i + bSize] += static_cast(carry); + // Store the least significant byte (lsb) of the product in the result + result[i + x] = static_cast(product & 0xFF); } - return result; + // If there is a carry, append it to the result + // But with the offset of bSize because of the previous inner loop + result[i + bSize] += static_cast(carry); } - [[nodiscard]] std::vector div( - const std::vector dividend, - const std::vector &divisor, - std::vector *remaining = nullptr) noexcept -{ - std::vector quotient; - std::uint8_t quotientBuffer = 0; - std::uint16_t quotientBitIndex = 0; - - // The index of the last bit in the dividendMask inside dividend - std::int64_t dividendIndex = getStartBitIndex(dividend); - // This copy's the most significant bit of the dividend - std::vector dividendMask = addBitFromNumber({0}, dividend, dividendIndex--); - - while (dividendIndex >= -1) { - - if (quotientBitIndex > 7) { - /* The quotient is stored from the most significant byte to the least significant - * byte, in contrast to all the other vector based numbers. - * Which is later reversed, at the end of the function. */ - quotient.push_back(quotientBuffer); - quotientBuffer = 0; - quotientBitIndex = 0; - } + return result; +} - if (isEqual(dividendMask, divisor) || isBigger(dividendMask, divisor)) { +[[nodiscard]] std::vector div( + const std::vector dividend, const std::vector &divisor, + std::vector *remaining = nullptr) noexcept { + std::vector quotient; + std::uint8_t quotientBuffer = 0; + std::uint16_t quotientBitIndex = 0; + + // The index of the last bit in the dividendMask inside dividend + std::int64_t dividendIndex = getStartBitIndex(dividend); + // This copy's the most significant bit of the dividend + std::vector dividendMask = addBitFromNumber({0}, dividend, dividendIndex--); + + while (dividendIndex >= -1) { + if (quotientBitIndex > 7) { + // The quotient is stored from the most significant byte to the least significant + // byte, in contrast to all the other vector based numbers. + // Which is later reversed, at the end of the function. + quotient.push_back(quotientBuffer); + quotientBuffer = 0; + quotientBitIndex = 0; + } - /* Stop the loop if the dividend is smaller than the divisor, because fractional - * digits are not supported */ - if (dividendIndex < 0) { - quotientBuffer <<= 1; + if (isEqual(dividendMask, divisor) || isBigger(dividendMask, divisor)) { + // Stop the loop if the dividend is smaller than the divisor, because fractional digits + // are not supported + if (dividendIndex < 0) { + quotientBuffer <<= 1; - // If remaining pointer is passed, set the remaining value - if (remaining != nullptr) *remaining = dividendMask; - break; - } + // If remaining pointer is passed, set the remaining value + if (remaining != nullptr) *remaining = dividendMask; + break; + } - // Shift the dividend and set the new bit as high + // Shift the dividend and set the new bit as high + quotientBuffer <<= 1; + quotientBuffer++; + quotientBitIndex++; + + dividendMask = sub(dividendMask, divisor); + dividendMask = addBitFromNumber(dividendMask, dividend, dividendIndex--); + } else { + // Stop the loop if the dividend is smaller than the divisor, because fractional digits + // are not supported + if (dividendIndex < 0) { quotientBuffer <<= 1; - quotientBuffer++; - quotientBitIndex++; - dividendMask = sub(dividendMask, divisor); - dividendMask = addBitFromNumber(dividendMask, dividend, dividendIndex--); - } else { - /* Stop the loop if the dividend is smaller than the divisor, because fractional - * digits are not supported */ - if (dividendIndex < 0) { - quotientBuffer <<= 1; - - // If remaining pointer is passed, set the remaining value - if (remaining != nullptr) *remaining = dividendMask; - break; - } - - /* if the divisor is bigger than the dividend, we need to shift the dividend - * and set the new bit as low */ - quotientBuffer <<= 1; - quotientBitIndex++; + // If remaining pointer is passed, set the remaining value + if (remaining != nullptr) *remaining = dividendMask; + break; + } - /* Stop the loop if the dividend is smaller than the divisor, because fractional - * digits are not supported */ - if (dividendIndex < 0) break; + // if the divisor is bigger than the dividend, we need to shift the dividend and set the + // new bit as low + quotientBuffer <<= 1; + quotientBitIndex++; - // Because dividendMask is smaller than divisor, we add the next bit from the dividend - dividendMask = addBitFromNumber(dividendMask, dividend, dividendIndex--); - } + // Stop the loop if the dividend is smaller than the divisor, because fractional digits + // are not supported + if (dividendIndex < 0) break; + + // Because dividendMask is smaller than divisor, we add the next bit from the dividend + dividendMask = addBitFromNumber(dividendMask, dividend, dividendIndex--); } + } - if (quotientBuffer != 0) quotient.push_back(quotientBuffer); + if (quotientBuffer != 0) quotient.push_back(quotientBuffer); - // Reverse the quotient - std::reverse(quotient.begin(), quotient.end()); + // Reverse the quotient + std::reverse(quotient.begin(), quotient.end()); - return quotient; - } + return quotient; +} - [[nodiscard]] std::vector pow( - const std::vector &a, - const std::uint64_t &pow) noexcept - { - // Copy the value from a into result while keeping a constant - std::vector result; - std::copy(a.begin(), a.end(), std::back_inserter(result)); - - // Start the loop at 1, because the first number is already assigned to result - for (std::uint32_t i = 1; i < pow; i++) { - result = mul(result, a); - } +[[nodiscard]] std::vector pow(const std::vector &a, + const std::uint64_t &pow) noexcept { + // Copy the value from a into result while keeping a constant + std::vector result; + std::copy(a.begin(), a.end(), std::back_inserter(result)); - return result; + // Start the loop at 1, because the first number is already assigned to result + for (std::uint32_t i = 1; i < pow; i++) { + result = mul(result, a); } + + return result; } +} // namespace operations #endif