diff --git a/src/openrct2/core/Crypt.CNG.cpp b/src/openrct2/core/Crypt.CNG.cpp index 385089b788..20b6a7ff22 100644 --- a/src/openrct2/core/Crypt.CNG.cpp +++ b/src/openrct2/core/Crypt.CNG.cpp @@ -18,6 +18,8 @@ #define __USE_CNG__ #endif +#undef __USE_CNG__ + #ifdef __USE_CNG__ #include "Crypt.h" @@ -36,14 +38,14 @@ template class CngHashAlgorithm final : public TBase { private: - const char * _algName; + const wchar_t * _algName; BCRYPT_ALG_HANDLE _hAlg{}; BCRYPT_HASH_HANDLE _hHash{}; PBYTE _pbHashObject{}; bool _reusable{}; public: - CngHashAlgorithm(const char * algName) + CngHashAlgorithm(const wchar_t * algName) { // BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8 _algName = algName; @@ -56,7 +58,7 @@ public: Dispose(); } - HashAlgorithm * Clear() override + TBase * Clear() override { if (_reusable) { @@ -71,7 +73,7 @@ public: return this; } - HashAlgorithm * Update(const void * data, size_t dataLen) override + TBase * Update(const void * data, size_t dataLen) override { auto status = BCryptHashData(_hHash, (PBYTE)data, (ULONG)dataLen, 0); if (!NT_SUCCESS(status)) @@ -81,9 +83,9 @@ public: return this; } - Result Finish() override + typename TBase::Result Finish() override { - Result result; + typename TBase::Result result; auto status = BCryptFinishHash(_hHash, result.data(), (ULONG)result.size(), 0); if (!NT_SUCCESS(status)) { @@ -96,7 +98,7 @@ private: void Initialise() { auto flags = _reusable ? BCRYPT_HASH_REUSABLE_FLAG : 0; - auto status = BCryptOpenAlgorithmProvider(&_hAlg, TAlg, nullptr, flags); + auto status = BCryptOpenAlgorithmProvider(&_hAlg, _algName, nullptr, flags); if (!NT_SUCCESS(status)) { throw std::runtime_error("BCryptOpenAlgorithmProvider failed: " + std::to_string(status)); diff --git a/src/openrct2/core/Crypt.OpenSSL.cpp b/src/openrct2/core/Crypt.OpenSSL.cpp index 5df154bb9c..be59c08cf2 100644 --- a/src/openrct2/core/Crypt.OpenSSL.cpp +++ b/src/openrct2/core/Crypt.OpenSSL.cpp @@ -18,6 +18,7 @@ #define __USE_CNG__ #endif +#undef __USE_CNG__ #ifndef __USE_CNG__ #include "Crypt.h" @@ -25,6 +26,7 @@ #include #include #include +#include static void OpenSSLThrowOnBadStatus(const std::string_view& name, int status) { @@ -109,12 +111,100 @@ public: class OpenSSLRsaKey final : public RsaKey { public: - EVP_PKEY * const EvpKey{}; + EVP_PKEY * GetEvpKey() const { return _evpKey; } - void SetPrivate(const std::string_view& pem) override { } - void SetPublic(const std::string_view& pem) override { } - std::string GetPrivate() override { return ""; } - std::string GetPublic() override { return ""; } + void SetPrivate(const std::string_view& pem) override + { + SetKey(pem, true); + } + + void SetPublic(const std::string_view& pem) override + { + SetKey(pem, false); + } + + std::string GetPrivate() override { return GetKey(true); } + + std::string GetPublic() override { return GetKey(false); } + +private: + EVP_PKEY * _evpKey{}; + + void SetKey(const std::string_view& pem, bool isPrivate) + { + // Read PEM data via BIO buffer + auto bio = BIO_new_mem_buf(pem.data(), (int)pem.size()); + if (bio == nullptr) + { + throw std::runtime_error("BIO_new_mem_buf failed"); + } + auto rsa = isPrivate ? + PEM_read_bio_RSAPrivateKey(bio, nullptr, nullptr, nullptr) : + PEM_read_bio_RSAPublicKey(bio, nullptr, nullptr, nullptr); + if (rsa == nullptr) + { + BIO_free_all(bio); + auto msg = isPrivate ? + "PEM_read_bio_RSAPrivateKey failed" : + "PEM_read_bio_RSAPublicKey failed"; + throw std::runtime_error(msg); + } + BIO_free_all(bio); + + if (isPrivate && !RSA_check_key(rsa)) + { + RSA_free(rsa); + throw std::runtime_error("PEM key was invalid"); + } + + // Assign new key + EVP_PKEY_free(_evpKey); + _evpKey = EVP_PKEY_new(); + EVP_PKEY_set1_RSA(_evpKey, rsa); + RSA_free(rsa); + } + + std::string GetKey(bool isPrivate) + { + if (_evpKey == nullptr) + { + throw std::runtime_error("No key has been assigned"); + } + + auto rsa = EVP_PKEY_get1_RSA(_evpKey); + if (rsa == nullptr) + { + throw std::runtime_error("EVP_PKEY_get1_RSA failed"); + } + if (!RSA_check_key(rsa)) + { + RSA_free(rsa); + throw std::runtime_error("Loaded RSA key is invalid"); + } + + auto bio = BIO_new(BIO_s_mem()); + if (bio == nullptr) + { + throw std::runtime_error("BIO_new failed"); + } + + auto status = isPrivate ? + PEM_write_bio_RSAPrivateKey(bio, rsa, nullptr, nullptr, 0, nullptr, nullptr) : + PEM_write_bio_RSAPublicKey(bio, rsa); + if (status != 1) + { + BIO_free_all(bio); + RSA_free(rsa); + throw std::runtime_error("PEM_write_bio_RSAPrivateKey failed"); + } + RSA_free(rsa); + + auto keylen = BIO_pending(bio); + std::string result(keylen, 0); + BIO_read(bio, result.data(), keylen); + BIO_free_all(bio); + return result; + } }; class OpenSSLRsaAlgorithm final : public RsaAlgorithm @@ -122,7 +212,7 @@ class OpenSSLRsaAlgorithm final : public RsaAlgorithm public: std::vector SignData(const RsaKey& key, const void * data, size_t dataLen) override { - auto evpKey = static_cast(key).EvpKey; + auto evpKey = static_cast(key).GetEvpKey(); EVP_MD_CTX * mdctx{}; try { @@ -160,7 +250,7 @@ public: bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) override { - auto evpKey = static_cast(key).EvpKey; + auto evpKey = static_cast(key).GetEvpKey(); EVP_MD_CTX * mdctx{}; try { @@ -182,7 +272,7 @@ public: OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status); } EVP_MD_CTX_destroy(mdctx); - return status == 0; + return status == 1; } catch (const std::exception&) { diff --git a/src/openrct2/core/Crypt.h b/src/openrct2/core/Crypt.h index 387898c92e..e2c2ed6e21 100644 --- a/src/openrct2/core/Crypt.h +++ b/src/openrct2/core/Crypt.h @@ -60,7 +60,7 @@ namespace Hash std::unique_ptr CreateRSA(); std::unique_ptr CreateRSAKey(); - Sha1Algorithm::Result SHA1(const void * data, size_t dataLen) + inline Sha1Algorithm::Result SHA1(const void * data, size_t dataLen) { return CreateSHA1() ->Update(data, dataLen) diff --git a/test/tests/CryptTests.cpp b/test/tests/CryptTests.cpp index 911ba2035b..8213389e33 100644 --- a/test/tests/CryptTests.cpp +++ b/test/tests/CryptTests.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include class CryptTests : public testing::Test @@ -96,3 +98,35 @@ TEST_F(CryptTests, SHA1_Many) } AssertHash("ac46948f97d69fa766706e932ce82562b4f73aa7", alg->Finish()); } + +TEST_F(CryptTests, RSA_Basic) +{ + std::vector data = { 0, 1, 2, 3, 4, 5, 6, 7 }; + + auto file = File::ReadAllBytes("C:/Users/Ted/Documents/OpenRCT2/keys/Ted.privkey"); + auto key = Hash::CreateRSAKey(); + key->SetPrivate(std::string_view((const char *)file.data(), file.size())); + + auto rsa = Hash::CreateRSA(); + auto signature = rsa->SignData(*key, data.data(), data.size()); + bool verified = rsa->VerifyData(*key, data.data(), data.size(), signature.data(), signature.size()); + ASSERT_TRUE(verified); +} + +TEST_F(CryptTests, RSA_VerifyWithPublic) +{ + std::vector data = { 7, 6, 5, 4, 3, 2, 1, 0 }; + + auto privateFile = File::ReadAllBytes("C:/Users/Ted/Documents/OpenRCT2/keys/Ted.privkey"); + auto privateKey = Hash::CreateRSAKey(); + privateKey->SetPrivate(std::string_view((const char *)privateFile.data(), privateFile.size())); + + auto publicFile = File::ReadAllBytes("C:/Users/Ted/Documents/OpenRCT2/keys/Ted-f60af9b4ea83cd884238bcbeba8e11545e70d574.pubkey"); + auto publicKey = Hash::CreateRSAKey(); + publicKey->SetPublic(std::string_view((const char *)publicFile.data(), publicFile.size())); + + auto rsa = Hash::CreateRSA(); + auto signature = rsa->SignData(*privateKey, data.data(), data.size()); + bool verified = rsa->VerifyData(*publicKey, data.data(), data.size(), signature.data(), signature.size()); + ASSERT_TRUE(verified); +}