diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d76f4b3..a6aff9b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,7 @@ add_executable(RSA-Encryptor main.cpp utility.cpp - keys.cpp + key.cpp encryption.cpp cli.cpp ) diff --git a/src/cli.cpp b/src/cli.cpp index 7ece695..750ab30 100644 --- a/src/cli.cpp +++ b/src/cli.cpp @@ -1,18 +1,18 @@ #include #include "cli.h" -#include "keys.h" +#include "key.h" void cli::help() { std::cout << "Welcome to RSA-Encryptor" << std::endl; } -void cli::key::create() { - keys::createRSAKey(); +void cli::keyManager::create() { + key::createRSAKey(); } -void cli::key::list() { +void cli::keyManager::list() { } -void cli::key::print(const std::string& name, bool publicKey, bool privateKey) { +void cli::keyManager::print(const std::string& name, bool publicKey, bool privateKey) { } diff --git a/src/cli.h b/src/cli.h index 972d471..9b1ed5b 100644 --- a/src/cli.h +++ b/src/cli.h @@ -1,9 +1,14 @@ +#ifndef CLI_H +#define CLI_H + namespace cli { void help(); - namespace key { + namespace keyManager { void create(); void list(); void print(const std::string& name, bool publicKey, bool privateKey); } -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/src/encryption.h b/src/encryption.h index 00e9feb..d4b13b9 100644 --- a/src/encryption.h +++ b/src/encryption.h @@ -1,3 +1,6 @@ +#ifndef ENCRYPTION_H +#define ENCRYPTION_H + #include #include @@ -6,4 +9,6 @@ class encryption private: public: std::int8_t encrypt(std::uint8_t data, unsigned long int e, unsigned long N); -}; \ No newline at end of file +}; + +#endif \ No newline at end of file diff --git a/src/key.cpp b/src/key.cpp index 237d242..44b4b23 100644 --- a/src/key.cpp +++ b/src/key.cpp @@ -12,6 +12,7 @@ std::filesystem::path key::keysPath() { } int key::keyExists(std::string name) { + // TODO fix function not case sensitive /* 0: key doesn't exist * 1: only publicKey exists * 2: only privateKey exists @@ -39,29 +40,172 @@ int key::keyExists(std::string name) { return status; } -int key::writeKey(const std::string& name, const unsigned long int data, const bool isPublic) { +int key::writeKey(const std::string& name, std::vector *data, const bool isPublic) { + if (data == nullptr) return -1; + std::filesystem::path keysFolder = keysPath(); std::filesystem::path keyFile = keysFolder / (name + (isPublic? ".pub":"")); + // load file in memory std::ofstream outFile(keyFile); + // key in base64 format + std::string base64Data = base64Encode(*data); + if (outFile.is_open()) { - std::string dataStr = std::to_string(data); outFile << "-----BEGIN RSA " << (isPublic? "PUBLIC":"PRIVATE") << " KEY-----\n"; - outFile << base64_encode(dataStr) << "\n"; + for (int i = 0; i < base64Data.size(); i += 64) { + outFile << base64Data.substr(i, 64) << "\n"; + } outFile << "-----END RSA " << (isPublic? "PUBLIC":"PRIVATE") << " KEY-----\n"; outFile.close(); } else { std::cerr << "Could not write to file: " << keyFile << std::endl; - return 0; + return -1; } return 1; } -std::string key::base64_encode(std::string &data) { - //TODO Add base 64 encode function - return data; +int key::readKey(const std::string &name, std::vector *data, bool isPublic) { + std::filesystem::path keysFolder = keysPath(); + + std::filesystem::path keyFile = keysFolder / (name + (isPublic? ".pub":"")); + // load file in memory + std::ifstream inFile(keyFile); + + // check if file is loaded correctly + if (!inFile.is_open()) { + std::cerr << "Die Datei konnte nicht geladen werden!" << std::endl; + return -1; + } + + std::string currentLine; + std::string keyString; + + while (getline (inFile, currentLine)) { + + if (currentLine.empty()) continue; + if (currentLine.at(0) == '-') continue; + + keyString += currentLine; + } + + // remove any line breaks + for (int c = keyString.length() - 1; c >= 0; c--) { + if (keyString.at(c) == '\n') { + keyString.erase(c); + } + } + + *data = base64Decode(keyString); + + return 1; +} + +std::string key::base64Encode(const std::vector &data) { + // TODO the encoding might not work if the data size is not dividable by 3 + if (data.size() % 3 != 0) std::cerr << "Encoding data size not dividable by 3" << std::endl; + + std::vector result; + int index = 0; + + while (index < data.size()) { + + uint32_t dataSegment = 0; + + // take 3 numbers from the vector + for (int i = 2; i >= 0; i--) { + // prevent from reading outside the vector + if (index > data.size() - 1) break; + + uint8_t number = data.at(index++); + dataSegment += number << i * 8; + } + + // 4 times, put 6 bits from dataSegment into result vector + for (int j = 0; j < 4; j++) { + dataSegment <<= 6; + result.push_back(static_cast((dataSegment & 0b00111111 << 24) >> 24)); + } + } + + std::string outString; + + // convert the numeric value to the corresponding char of the base64 charset + for (auto c : result) { + outString += base64Chars[c]; + } + + return outString; +} + +std::vector key::base64Decode(std::string data) { + // TODO the decoding might not work if the data size is not dividable by 4 + if (data.size() % 4 != 0) std::cerr << "Decoding data size not dividable by 4" << std::endl; + + std::vector result; + + while (!data.empty()) { + // remove 4 chars from the string + std::string letterSegment = data.substr(0, 4); + data.erase(0, 4); + + uint32_t dataSegment = 0; + + // put 4 chars in one binary string + for (int i = 3; i >= 0; i--) { + // check if there are remaining letters + if (i > letterSegment.length() - 1) continue; + + const uint8_t number = getBase64Index(letterSegment.at(3 - i)); + + // check for valid char code + if (number > 0b111111) { + std::cerr << "trying to Decode not valid char: " << letterSegment.at(i) << std::endl; + continue; + } + dataSegment += number << i * 6; + } + + // read data from binary string + for (int j = 0; j < 3; j++) { + dataSegment <<= 8; + result.push_back((dataSegment & 0xFF << 24) >> 24); + } + } + return result; +} + +uint8_t key::getBase64Index(char letter) { + for (int i = 0; base64Chars[i] != '\0'; i++) { + if (base64Chars[i] == letter) { + return i; + } + } + return 0; +} + +int key::createKey(std::vector *keyPublic, std::vector *keyPrivate) { + *keyPublic = { + 23, 87, 45, 190, 12, 78, 34, 210, 56, 89, + 123, 67, 90, 150, 32, 76, 54, 200, 11, 99, + 101, 145, 67, 189, 43, 88, 29, 176, 58, 92, + 111, 134, 78, 201, 15, 84, 39, 220, 66, 97, + 105, 142, 71, 185, 49, 81, 27, 170, 53, 95, + 41 + }; + *keyPrivate = { + 34, 78, 123, 56, 89, 210, 45, 190, 12, 87, + 67, 150, 32, 76, 54, 200, 11, 99, 101, 145, + 67, 189, 43, 88, 29, 176, 58, 92, 111, 134, + 78, 201, 15, 84, 39, 220, 66, 97, 105, 142, + 71, 185, 49, 81, 27, 170, 53, 95, 102, 147, + 68, 191, 44, 85, 30, 177, 59, 93, 112, 135, + 79, 202, 16, 85, 40, 221, 67, 98, 23, 87, + 91, 65 + }; + return 1; } void key::createRSAKey() { @@ -84,7 +228,12 @@ void key::createRSAKey() { return; } - if (writeKey(keyName, 0, true) == 1 && writeKey(keyName, 0, false) == 1) { + auto* keyPublic = new std::vector(); + auto* keyPrivate = new std::vector(); + + createKey(keyPublic, keyPrivate); + + if (writeKey(keyName, keyPublic, true) == 1 && writeKey(keyName, keyPrivate, false) == 1) { std::cout << keyName + " key successfully created!\n"; std::cout << "You can find your newly created key here: " << keysFolder << std::endl; } else { @@ -92,10 +241,14 @@ void key::createRSAKey() { } } -std::pair getPrivateKey(std::string& name) { - return {647'090'566'899, 234'099'456'876'004}; +std::vector * key::getPrivateKey(std::string &name) { + std::vector* data = new std::vector; + readKey(name, data, false); + return data; } -std::pair getPublicKey(std::string& name) { - return {143'548'453'234, 234'099'456'876'004}; -} \ No newline at end of file +std::vector * key::getPublicKey(std::string &name) { + std::vector* data = new std::vector; + readKey(name, data, true); + return data; +} diff --git a/src/key.h b/src/key.h index 823741a..abdd808 100644 --- a/src/key.h +++ b/src/key.h @@ -1,5 +1,10 @@ +#ifndef KEY_H +#define KEY_H + #include #include +#include +#include #define KEY_FOLDER "rsa-keys" @@ -10,17 +15,29 @@ enum { BOTH }; +const char base64Chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + class key { private: static std::filesystem::path keysPath(); static int keyExists(std::string name); static void generatePublicFromPrivate(); - static int writeKey(const std::string &name, unsigned long int data, bool isPublic); - static std::string base64_encode(std::string &data); + + static int writeKey(const std::string &name, std::vector *data, bool isPublic); + static int readKey(const std::string &name, std::vector *data, bool isPublic); + + static std::string base64Encode(const std::vector &data); + static std::vector base64Decode(std::string data); + + static uint8_t getBase64Index(char letter); + + static int createKey(std::vector *keyPublic, std::vector *keyPrivate); public: static void createRSAKey(); - static std::pair getPrivateKey(std::string &name); - static std::pair getPublicKey(std::string &name); -}; \ No newline at end of file + static std::vector* getPrivateKey(std::string &name); + static std::vector* getPublicKey(std::string &name); +}; + +#endif \ No newline at end of file diff --git a/src/keys.cpp b/src/keys.cpp deleted file mode 100644 index a8a51e2..0000000 --- a/src/keys.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include "keys.h" - - -void keys::createRSAKey() { - std::string keyName; - - std::cout << "Enter the name of the key: "; - std::cin >> keyName; - - // Navigate to the root directory of the project - std::filesystem::path programRootPath = std::filesystem::current_path().parent_path(); - std::filesystem::path keysFolder = programRootPath / KEY_FOLDER; - - // Check if keys folder exists - if (!std::filesystem::exists(keysFolder)) { - // Create the "RSA-Keys" directory in the root folder fo the program - std::filesystem::create_directory(keysFolder); - } - - std::cout << "You can find your newly created key here: " << keysFolder << std::endl; -} - -std::pair getPrivateKey(std::string& name) { - return {647'090'566'899, 234'099'456'876'004}; -} - -std::pair getPublicKey(std::string& name) { - return {143'548'453'234, 234'099'456'876'004}; -} \ No newline at end of file diff --git a/src/keys.h b/src/keys.h deleted file mode 100644 index 670fc4d..0000000 --- a/src/keys.h +++ /dev/null @@ -1,13 +0,0 @@ -#include - -#define KEY_FOLDER "rsa-keys" - - -class keys { - -public: - static void createRSAKey(); - - std::pair getPrivateKey(std::string& name); - std::pair getPublicKey(std::string& name); -}; \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 07b69a1..15e2ebd 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -6,6 +6,7 @@ #include "utility.h" #include "cli.h" +#include "key.h" #include "vec/operations.h" int main(int argc, char* argv[]) { @@ -31,19 +32,17 @@ int main(int argc, char* argv[]) { } } - std::cout << arguments.publicKey << std::endl; - std::string feature = argv[1]; if (feature == "key" && argv[2] != nullptr) { std::string subFeature = argv[2]; if (subFeature == "create") { - cli::key::create(); + cli::keyManager::create(); } else if (subFeature == "list") { - cli::key::list(); - } else if (subFeature == "print") { - cli::key::print(argv[3], arguments.publicKey, arguments.privateKey); + cli::keyManager::list(); + } else if (subFeature == "print" && argv[3] != nullptr) { + cli::keyManager::print(argv[3], arguments.publicKey, arguments.privateKey); } } else if (feature == "help") { @@ -55,20 +54,5 @@ int main(int argc, char* argv[]) { cli::help(); } - // std::cout << util.checkForPrime(54557) << std::endl; - // std::cout << util.checkForPrime(29) << std::endl; - //std::vector num1 = {255, 255, 255}; // 16777215 in dezimal - //std::vector num2 = {0x01}; // Add one - - std::vector num1 = convertToVector(16777215); // 16777215 in dezimal - std::vector num2 = convertToVector(1); // Add one - - std::vector result = add(num1, num2); - - // This gives out the result as a hex number - for (auto it = result.rbegin(); it != result.rend(); ++it) { - printf("%02X", *it); - } - std::cout << std::endl; return 0; } diff --git a/src/utility.h b/src/utility.h index 7c835fa..29dd202 100644 --- a/src/utility.h +++ b/src/utility.h @@ -1,3 +1,6 @@ +#ifndef UTILITY_H +#define UTILITY_H + #include class utility @@ -10,3 +13,4 @@ class utility bool checkForPrime(int number); }; +#endif \ No newline at end of file diff --git a/src/vec/operations.h b/src/vec/operations.h index ed69c75..c76a995 100644 --- a/src/vec/operations.h +++ b/src/vec/operations.h @@ -1,76 +1,362 @@ +#ifndef VEC_OPERATIONS_H +#define VEC_OPERATIONS_H + #include #include #include -[[nodiscard]] std::vector add(const std::vector& a, const std::vector& b) { - // Initialize the result vector to store the sum - std::vector result; +namespace operations { + [[nodiscard]] std::uint64_t getStartBitIndex (const std::vector &a) { + std::uint64_t index = 0; + std::uint64_t bitsSince1 = 0; + + for (std::uint8_t number: a) { + // Skip 8 bits if all zero + if (number == 0) { + bitsSince1 += 8; + continue; + } + // Go over every bit in number + for (std::uint64_t i = 0; i < sizeof(std::uint8_t) * 8; i++) { + // Select the current bit using a bitmask + if ((number & 0b1 << i) != 0) { + // Increment the index by the bits since the last increment + index += bitsSince1 + 1; + bitsSince1 = 0; + } else { + bitsSince1++; + } + } + } + + // Returns position of the first bit needed for the number + return index - 1; + } + + /* This function picks one bit from pickNumber and places it at the least significant + * position of number. */ + [[nodiscard]] std::vector addBitFromNumber (const std::vector &number, const std::vector &pickNumber, std::uint32_t index) { + std::vector result; + + // Check if the index is valid + if (index > getStartBitIndex(pickNumber)) { + return {0}; + } + + // Check if the bit at the given index is set + bool mostSignificantBit = pickNumber[index / 8] & 0b1 << index % 8; + + for (std::uint8_t currentByte : number) { + result.push_back((currentByte << 1) + mostSignificantBit); + + // Check if the most significant bit of the current byte is set + mostSignificantBit = currentByte & 0b1000000 == 0b10000000; + } + + /* In case the length of the vector and the actual number length matches, the number + * increases by one vector element with value 1 */ + if (mostSignificantBit) result.push_back(0b1); + + return result; + } - // Get the max iterations based on the largest vector - int iterations = std::max(a.size(), b.size()); + [[nodiscard]] bool isZero (const std::vector &a) { + for (std::uint8_t number: a) { + if (number != 0) return false; + } - // Carry to handle overflow between bytes - std::uint16_t carry = 0; + return true; + } - for (int i = 0; i < iterations; i++) + [[nodiscard]] bool isEqual ( + const std::vector &a, + const std::vector &b) noexcept { - // Set up a holder for the sum - std::uint16_t sum = carry; + const std::uint32_t aSize = a.size(); + const std::uint32_t bSize = b.size(); + const std::uint64_t iterations = aSize > bSize ? aSize : bSize; + + // In case both numbers are the same length these variables could point ot the same vector + const std::uint32_t shortest = aSize < bSize ? aSize : bSize; + const std::vector longest = aSize > bSize ? a : b; + + for (std::uint64_t i = 0; i < iterations; i++) { + if (i >= shortest) { + if (longest[i] != 0) return false; + } else { + if (a[i] != b[i]) return false; + } + } + + return true; + } - // Only add to sum if we actually have a value in a - if (i < a.size()) + [[nodiscard]] bool isBigger ( + const std::vector &a, + const std::vector &b) noexcept + { + const std::uint32_t aSize = a.size(); + const std::uint32_t bSize = b.size(); + const std::uint64_t iterations = aSize > bSize ? aSize : bSize; + + // We loop reverse throw all the vectors to catch leading zeros + for (std::int64_t i = iterations - 1; i >= 0; i--) { + + const std::uint8_t aValue = (i < aSize) ? a[i] : 0; + const std::uint8_t bValue = (i < bSize) ? b[i] : 0; + + if (aValue > bValue) + { + return true; + } + } + + return false; + } + + [[nodiscard]] std::vector add( + const std::vector &a, + const std::vector &b) noexcept + { + // Time complexity O(iterations) + // Initialize the result vector to store the sum + std::vector result; + + // Get the max iterations based on the largest vector + const int iterations = std::max(a.size(), b.size()); + + // Carry to handle overflow between bytes + std::uint16_t carry = 0; + + for (int i = 0; i < iterations; i++) + { + // Set up a holder for the sum + std::uint16_t sum = carry; + + // Only add to sum if we actually have a value in a + if (i < a.size()) + { + sum += a[i]; + } + + // Only add to sum if we actually have a value in b + if (i < b.size()) + { + sum += b[i]; + } + + // Calculate the carry which is the overflow beyond 255 + carry = sum >> 8; + + // Clear out everything except the lsb + // This is done by masking the sum with 255 + // For example + // 0000000101111111 (=383) + // 0000000011111111 (=255) + // ---------------- + // 0000000100000000 + sum &= 0xFF; + + // Append the least significant byte of the sum to the result + result.push_back(static_cast(sum)); + } + + // If we got a carry we append this as well + if (carry) { - sum += a[i]; + result.push_back(static_cast(carry)); } - - // Only add to sum if we actually have a value in b - if (i < b.size()) + + return result; + } + + // The return value can only be positive, if it would be negative, 0 is returned + [[nodiscard]] std::vector sub( + const std::vector &a, + const std::vector &b) noexcept + { + // Prevent from ending the subtraction before going over the hole subtractor and stop if the result can only be negative + if (getStartBitIndex(b) > getStartBitIndex(a)) return {0}; + + std::vector result; + + // Handle an underflow when subtracting + bool borrow = false; + + for (int i = 0; i < a.size(); i++) { - sum += b[i]; + std::int32_t subtract; + // Check for the end of the subtractor + if (i >= b.size()) { + subtract = borrow; + } else { + subtract = borrow + b[i]; + } + borrow = false; + + if (a[i] >= subtract) { + // If the current number is as least as big as subtract + std::uint8_t number = a[i] - subtract; + result.push_back(number); + } else { + // Borrow from the next number + subtract -= 256; + borrow = true; + + // Here subtract can only be 0 or negative + std::uint8_t number = a[i] - subtract; + result.push_back(number); + } } - // Calculate the carry for the next byte (overflow beyond 255) - carry = sum >> 8; + return result; + } + + [[nodiscard]] std::vector convertToVector(std::uint64_t number) noexcept { + + std::vector result; - // Clear out everything except the lsb - // This is done by masking the sum with 255 - // For example - // 0000000101111111 (=383) - // 0000000011111111 (=255) - // ---------------- - // 0000000100000000 - sum &= 0xFF; + // Loop until the whole number is zero + while (number) + { + // We only care for the lsb + result.push_back(static_cast(number & 0xFF)); + + // Shift the number by 8 to the right + number >>= 8; + } - // Append the least significant byte of the sum to the result - result.push_back(static_cast(sum)); + return result; } - // If we got a carry we append this as well - if (carry) + [[nodiscard]] std::vector mul( + const std::vector &a, + const std::vector &b) noexcept { - result.push_back(static_cast(carry)); + // Time complexity O(aSize * bSize) + const std::uint64_t aSize = a.size(); + const std::uint64_t bSize = b.size(); + + // Initialize the result vector with zeros, with the size of aSize + bSize + std::vector result(aSize + bSize, 0); + + for (std::uint64_t i = 0; i < aSize; i++) + { + // Set up the carry value for each iteration + std::uint16_t carry = 0; + + for (uint64_t x = 0; x < bSize; x++) + { + // Calculate the product by adding up the previous result, the carry, and the new product + std::uint16_t product = result[i + x] + carry + (a[i] * b[x]); + + // Calculate the carry which is the overflow beyond 255 + carry = product >> 8; + + // Store the least significant byte (lsb) of the product in the result + result[i + x] = static_cast(product & 0xFF); + } + + // If there is a carry, append it to the result + // But with the offset of bSize because of the previous inner loop + result[i + bSize] += static_cast(carry); + } + + return result; } - return result; -} + [[nodiscard]] std::vector div( + const std::vector dividend, + const std::vector &divisor, + std::vector *remaining = nullptr) noexcept +{ + std::vector quotient; + std::uint8_t quotientBuffer = 0; + std::uint16_t quotientBitIndex = 0; + + // The index of the last bit in the dividendMask inside dividend + std::int64_t dividendIndex = getStartBitIndex(dividend); + // This copy's the most significant bit of the dividend + std::vector dividendMask = addBitFromNumber({0}, dividend, dividendIndex--); + + while (dividendIndex >= -1) { + + if (quotientBitIndex > 7) { + /* The quotient is stored from the most significant byte to the least significant + * byte, in contrast to all the other vector based numbers. + * Which is later reversed, at the end of the function. */ + quotient.push_back(quotientBuffer); + quotientBuffer = 0; + quotientBitIndex = 0; + } + + if (isEqual(dividendMask, divisor) || isBigger(dividendMask, divisor)) { + + /* Stop the loop if the dividend is smaller than the divisor, because fractional + * digits are not supported */ + if (dividendIndex < 0) { + quotientBuffer <<= 1; + + // If remaining pointer is passed, set the remaining value + if (remaining != nullptr) *remaining = dividendMask; + break; + } + + // Shift the dividend and set the new bit as high + quotientBuffer <<= 1; + quotientBuffer++; + quotientBitIndex++; + + dividendMask = sub(dividendMask, divisor); + dividendMask = addBitFromNumber(dividendMask, dividend, dividendIndex--); + } else { + /* Stop the loop if the dividend is smaller than the divisor, because fractional + * digits are not supported */ + if (dividendIndex < 0) { + quotientBuffer <<= 1; + + // If remaining pointer is passed, set the remaining value + if (remaining != nullptr) *remaining = dividendMask; + break; + } -[[nodiscard]] std::vector convertToVector(std::uint64_t number) { - - std::vector result; + /* if the divisor is bigger than the dividend, we need to shift the dividend + * and set the new bit as low */ + quotientBuffer <<= 1; + quotientBitIndex++; - // Loop until the whole number is zero - while (number) { - // We only care for the lsb - result.push_back(static_cast(number & 0xFF)); + /* Stop the loop if the dividend is smaller than the divisor, because fractional + * digits are not supported */ + if (dividendIndex < 0) break; - // Shift the number by 8 to the right - number >>= 8; + // Because dividendMask is smaller than divisor, we add the next bit from the dividend + dividendMask = addBitFromNumber(dividendMask, dividend, dividendIndex--); + } + } + + if (quotientBuffer != 0) quotient.push_back(quotientBuffer); + + // Reverse the quotient + std::reverse(quotient.begin(), quotient.end()); + + return quotient; } - return result; -} + [[nodiscard]] std::vector pow( + const std::vector &a, + const std::uint64_t &pow) noexcept + { + // Copy the value from a into result while keeping a constant + std::vector result; + std::copy(a.begin(), a.end(), std::back_inserter(result)); + + // Start the loop at 1, because the first number is already assigned to result + for (std::uint32_t i = 1; i < pow; i++) { + result = mul(result, a); + } -[[nodiscard]] std::vector mul(const std::vector& a, const std::vector& b) { - std::vector result; - return result; + return result; + } } + +#endif