diff --git a/src/openrct2/core/Crypt.CNG.cpp b/src/openrct2/core/Crypt.CNG.cpp deleted file mode 100644 index d476d652e6..0000000000 --- a/src/openrct2/core/Crypt.CNG.cpp +++ /dev/null @@ -1,572 +0,0 @@ -#pragma region Copyright (c) 2018 OpenRCT2 Developers -/***************************************************************************** -* OpenRCT2, an open source clone of Roller Coaster Tycoon 2. -* -* OpenRCT2 is the work of many authors, a full list can be found in contributors.md -* For more information, visit https://github.com/OpenRCT2/OpenRCT2 -* -* OpenRCT2 is free software: you can redistribute it and/or modify -* it under the terms of the GNU General Public License as published by -* the Free Software Foundation, either version 3 of the License, or -* (at your option) any later version. -* -* A full copy of the GNU General Public License can be found in licence.txt -*****************************************************************************/ -#pragma endregion - -#if defined(_WIN32) && !defined(__USE_OPENSSL__) -#define __USE_CNG__ -#endif - -#ifdef __USE_CNG__ - -#include "Crypt.h" -#include "../platform/Platform2.h" -#include "IStream.hpp" -#include -#include -#include -#include - -// CNG: Cryptography API: Next Generation (CNG) -// available in Windows Vista onwards. -#define NOMINMAX -#define WIN32_LEAN_AND_MEAN -#include -#include -#include -#include -#define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0) - -static void CngThrowOnBadStatus(const std::string_view& name, NTSTATUS status) -{ - if (!NT_SUCCESS(status)) - { - throw std::runtime_error(std::string(name) + " failed: " + std::to_string(status)); - } -} - -static void ThrowBadAllocOnNull(const void * ptr) -{ - if (ptr == nullptr) - { - throw std::bad_alloc(); - } -} - -template -class CngHashAlgorithm final : public TBase -{ -private: - const wchar_t * _algName; - BCRYPT_ALG_HANDLE _hAlg{}; - BCRYPT_HASH_HANDLE _hHash{}; - PBYTE _pbHashObject{}; - bool _reusable{}; - -public: - CngHashAlgorithm(const wchar_t * algName) - { - // BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8 - _algName = algName; - _reusable = Platform::IsOSVersionAtLeast(6, 2, 0); - Initialise(); - } - - ~CngHashAlgorithm() - { - Dispose(); - } - - TBase * Clear() override - { - if (_reusable) - { - // Finishing the current digest clears the state ready for a new digest - Finish(); - } - else - { - Dispose(); - Initialise(); - } - return this; - } - - TBase * Update(const void * data, size_t dataLen) override - { - auto status = BCryptHashData(_hHash, (PBYTE)data, (ULONG)dataLen, 0); - CngThrowOnBadStatus("BCryptHashData", status); - return this; - } - - typename TBase::Result Finish() override - { - typename TBase::Result result; - auto status = BCryptFinishHash(_hHash, result.data(), (ULONG)result.size(), 0); - CngThrowOnBadStatus("BCryptFinishHash", status); - return result; - } - -private: - void Initialise() - { - auto flags = _reusable ? BCRYPT_HASH_REUSABLE_FLAG : 0; - auto status = BCryptOpenAlgorithmProvider(&_hAlg, _algName, nullptr, flags); - CngThrowOnBadStatus("BCryptOpenAlgorithmProvider", status); - - // Calculate the size of the buffer to hold the hash object - DWORD cbHashObject{}; - DWORD cbData{}; - status = BCryptGetProperty(_hAlg, BCRYPT_OBJECT_LENGTH, (PBYTE)&cbHashObject, sizeof(DWORD), &cbData, 0); - CngThrowOnBadStatus("BCryptGetProperty", status); - - // Create a hash - _pbHashObject = (PBYTE)HeapAlloc(GetProcessHeap(), 0, cbHashObject); - ThrowBadAllocOnNull(_pbHashObject); - status = BCryptCreateHash(_hAlg, &_hHash, _pbHashObject, cbHashObject, nullptr, 0, 0); - CngThrowOnBadStatus("BCryptCreateHash", status); - } - - void Dispose() - { - BCryptCloseAlgorithmProvider(_hAlg, 0); - BCryptDestroyHash(_hHash); - HeapFree(GetProcessHeap(), 0, _pbHashObject); - - _hAlg = {}; - _hHash = {}; - _pbHashObject = {}; - } -}; - -class DerReader -{ -private: - ivstream _stream; - - template - T Read(std::istream& stream) - { - T value; - stream.read((char*)&value, sizeof(T)); - return value; - } - - template - std::vector Read(std::istream& stream, size_t count) - { - std::vector values(count); - stream.read((char*)values.data(), sizeof(T) * count); - return values; - } - - int ReadTag(std::istream& stream) - { - auto a = Read(stream); - // auto tagClass = a >> 6; - // auto tagConstructed = ((a & 0x20) != 0); - auto tagNumber = a & 0x1F; - if (tagNumber == 0x1F) - { - throw std::runtime_error("Unsupported DER tag"); - } - return tagNumber; - } - - int ReadLength(std::istream& stream) - { - auto a = Read(stream); - auto len = a & 0x7F; - if (len == a) - { - return len; - } - if (len > 6) - { - throw std::runtime_error("Length over 48 bits not supported at this position"); - } - if (len == 0) - { - throw std::runtime_error("Unknown length"); - } - auto result = 0; - for (auto i = 0; i < len; i++) - { - result = (result << 8) + Read(stream); - } - return result; - } - -public: - DerReader(const std::vector& data) - : _stream(data) - { - } - - void ReadSequenceHeader() - { - auto a = Read(_stream); - if (a == 0x8130) - { - Read(_stream); - } - else if (a == 0x8230) - { - Read(_stream); - } - else - { - throw std::runtime_error("Invalid DER code"); - } - } - - std::vector ReadInteger() - { - auto t = ReadTag(_stream); - if (t != 2) - { - throw std::runtime_error("Expected INTEGER"); - } - auto len = ReadLength(_stream); - auto result = Read(_stream, len); - - auto v = result[0]; - auto neg = (v > 127); - auto pad = neg ? 255 : 0; - for (size_t i = 0; i < result.size(); i++) - { - if (result[i] != pad) - { - result.erase(result.begin(), result.begin() + i); - break; - } - } - return result; - } -}; - -class DerWriter -{ -private: - std::vector _buffer; - -public: - void WriteSequenceHeader() - { - _buffer.push_back(0x30); - _buffer.push_back(0x81); - _buffer.push_back(0x89); - } - - void WriteInteger(const std::vector& data) - { - if (data.size() < 128) - { - _buffer.push_back((uint8_t)data.size()); - } - else if (data.size() <= std::numeric_limits().max()) - { - _buffer.push_back(0b10000001); - _buffer.push_back((uint8_t)data.size()); - } - else if (data.size() <= std::numeric_limits().max()) - { - _buffer.push_back(0b10000010); - _buffer.push_back((data.size() >> 8) & 0xFF); - _buffer.push_back(data.size() & 0xFF); - } - _buffer.insert(_buffer.end(), data.begin(), data.end()); - } - - std::vector&& Complete() - { - return std::move(_buffer); - } -}; - -class CngRsaKey final : public RsaKey -{ -private: - struct RsaKeyParams - { - std::vector Modulus; - std::vector Exponent; - std::vector Prime1; - std::vector Prime2; - }; - -public: - NCRYPT_KEY_HANDLE GetKeyHandle() const { return _hKey; } - - ~CngRsaKey() - { - NCryptFreeObject(_hKey); - } - - void SetPrivate(const std::string_view& pem) override - { - auto der = ReadPEM(pem, SZ_PRIVATE_BEGIN_TOKEN, SZ_PRIVATE_END_TOKEN); - DerReader derReader(der); - RsaKeyParams params; - derReader.ReadSequenceHeader(); - derReader.ReadInteger(); - params.Modulus = derReader.ReadInteger(); - params.Exponent = derReader.ReadInteger(); - derReader.ReadInteger(); - params.Prime1 = derReader.ReadInteger(); - params.Prime2 = derReader.ReadInteger(); - _hKey = ImportKey(params); - } - - void SetPublic(const std::string_view& pem) override - { - auto der = ReadPEM(pem, SZ_PUBLIC_BEGIN_TOKEN, SZ_PUBLIC_END_TOKEN); - DerReader derReader(der); - RsaKeyParams params; - derReader.ReadSequenceHeader(); - params.Modulus = derReader.ReadInteger(); - params.Exponent = derReader.ReadInteger(); - _hKey = ImportKey(params); - } - - std::string GetPrivate() override - { - return ""; - } - - std::string GetPublic() override - { - auto params = ExportKey(true); - DerWriter derWriter; - derWriter.WriteSequenceHeader(); - derWriter.WriteInteger(params.Modulus); - derWriter.WriteInteger(params.Exponent); - auto derBytes = derWriter.Complete(); - auto b64 = EncodeBase64(derBytes); - - std::ostringstream sb; - sb << std::string(SZ_PUBLIC_BEGIN_TOKEN) << std::endl; - sb << b64 << std::endl; - sb << std::string(SZ_PUBLIC_END_TOKEN) << std::endl; - return sb.str(); - } - -private: - static constexpr std::string_view SZ_PUBLIC_BEGIN_TOKEN = "-----BEGIN RSA PUBLIC KEY-----"; - static constexpr std::string_view SZ_PUBLIC_END_TOKEN = "-----END RSA PUBLIC KEY-----"; - static constexpr std::string_view SZ_PRIVATE_BEGIN_TOKEN = "-----BEGIN RSA PRIVATE KEY-----"; - static constexpr std::string_view SZ_PRIVATE_END_TOKEN = "-----END RSA PRIVATE KEY-----"; - - NCRYPT_KEY_HANDLE _hKey{}; - - static std::vector ReadPEM(const std::string_view& pem, const std::string_view& beginToken, const std::string_view& endToken) - { - auto beginPos = pem.find(beginToken); - auto endPos = pem.find(endToken); - if (beginPos != std::string::npos && endPos != std::string::npos) - { - beginPos += beginToken.size(); - auto code = Trim(pem.substr(beginPos, endPos - beginPos)); - return DecodeBase64(code); - } - throw std::runtime_error("Invalid PEM file"); - } - - static std::string_view Trim(std::string_view input) - { - for (size_t i = 0; i < input.size(); i++) - { - if (input[i] >= '!') - { - input.remove_prefix(i); - break; - } - } - for (size_t i = input.size() - 1; i >= 0; i--) - { - if (input[i] >= '!') - { - input = input.substr(0, i + 1); - break; - } - } - return input; - } - - static std::string EncodeBase64(const std::vector& input) - { - DWORD chString; - if (!CryptBinaryToStringA(input.data(), (DWORD)input.size(), CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, NULL, &chString)) - { - throw std::runtime_error("CryptBinaryToStringA failed"); - } - std::string result(chString, 0); - if (!CryptBinaryToStringA(input.data(), (DWORD)input.size(), CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, result.data(), &chString)) - { - throw std::runtime_error("CryptBinaryToStringA failed"); - } - return result; - } - - static std::vector DecodeBase64(const std::string_view& input) - { - DWORD cbBinary; - if (!CryptStringToBinaryA(input.data(), (DWORD)input.size(), CRYPT_STRING_BASE64, NULL, &cbBinary, NULL, NULL)) - { - throw std::runtime_error("CryptStringToBinaryA failed"); - } - std::vector result(cbBinary); - if (!CryptStringToBinaryA(input.data(), (DWORD)input.size(), CRYPT_STRING_BASE64, result.data(), &cbBinary, NULL, NULL)) - { - throw std::runtime_error("CryptStringToBinaryA failed"); - } - return result; - } - - static NCRYPT_KEY_HANDLE ImportKey(const RsaKeyParams& params) - { - bool isPublic = params.Prime1.size() == 0; - auto blobType = isPublic ? BCRYPT_RSAPUBLIC_BLOB : BCRYPT_RSAPRIVATE_BLOB; - - BCRYPT_RSAKEY_BLOB header{}; - header.Magic = isPublic ? BCRYPT_RSAPUBLIC_MAGIC : BCRYPT_RSAPRIVATE_MAGIC; - header.BitLength = (ULONG)(params.Modulus.size() * 8); - header.cbPublicExp = (ULONG)params.Exponent.size(); - header.cbModulus = (ULONG)params.Modulus.size(); - header.cbPrime1 = (ULONG)params.Prime1.size(); - header.cbPrime2 = (ULONG)params.Prime2.size(); - - std::vector blob; - blob.insert(blob.end(), (uint8_t*)&header, (uint8_t*)(&header + 1)); - blob.insert(blob.end(), params.Exponent.begin(), params.Exponent.end()); - blob.insert(blob.end(), params.Modulus.begin(), params.Modulus.end()); - blob.insert(blob.end(), params.Prime1.begin(), params.Prime1.end()); - blob.insert(blob.end(), params.Prime2.begin(), params.Prime2.end()); - - NCRYPT_PROV_HANDLE hProv{}; - NCRYPT_KEY_HANDLE hKey{}; - auto status = NCryptOpenStorageProvider(&hProv, MS_KEY_STORAGE_PROVIDER, 0); - CngThrowOnBadStatus("NCryptOpenStorageProvider", status); - status = NCryptImportKey(hProv, NULL, blobType, NULL, &hKey, (PBYTE)blob.data(), (DWORD)blob.size(), 0); - NCryptFreeObject(hProv); - CngThrowOnBadStatus("NCryptImportKey", status); - return hKey; - } - - RsaKeyParams ExportKey(bool onlyPublic) - { - auto blobType = onlyPublic ? BCRYPT_RSAPUBLIC_BLOB : BCRYPT_RSAPRIVATE_BLOB; - - std::vector output; - NCRYPT_PROV_HANDLE hProv{}; - try - { - auto status = NCryptOpenStorageProvider(&hProv, MS_KEY_STORAGE_PROVIDER, 0); - CngThrowOnBadStatus("NCryptOpenStorageProvider", status); - DWORD cbOutput{}; - status = NCryptExportKey(hProv, _hKey, blobType, NULL, NULL, 0, &cbOutput, 0); - CngThrowOnBadStatus("NCryptExportKey", status); - output = std::vector(cbOutput); - status = NCryptExportKey(hProv, _hKey, blobType, NULL, output.data(), cbOutput, NULL, 0); - CngThrowOnBadStatus("NCryptExportKey", status); - NCryptFreeObject(hProv); - } - catch (const std::exception&) - { - NCryptFreeObject(hProv); - } - - RsaKeyParams params; - const auto& header = *((BCRYPT_RSAKEY_BLOB*)output.data()); - size_t i = sizeof(BCRYPT_RSAKEY_BLOB); - params.Modulus.insert(params.Modulus.end(), output.begin() + i, output.begin() + i + header.cbModulus); - i += header.cbModulus; - params.Exponent.insert(params.Exponent.end(), output.begin() + i, output.begin() + i + header.cbPublicExp); - return params; - } -}; - -class CngRsaAlgorithm final : public RsaAlgorithm -{ -public: - std::vector SignData(const RsaKey& key, const void * data, size_t dataLen) override - { - auto hKey = static_cast(key).GetKeyHandle(); - auto [cbHash, pbHash] = HashData(data, dataLen); - auto [cbSignature, pbSignature] = std::tuple(); - try - { - BCRYPT_PKCS1_PADDING_INFO paddingInfo{ BCRYPT_SHA256_ALGORITHM }; - auto status = NCryptSignHash(hKey, &paddingInfo, pbHash, cbHash, NULL, 0, &cbSignature, BCRYPT_PAD_PKCS1); - CngThrowOnBadStatus("NCryptSignHash", status); - pbSignature = (PBYTE)HeapAlloc(GetProcessHeap(), 0, cbSignature); - ThrowBadAllocOnNull(pbSignature); - status = NCryptSignHash(hKey, &paddingInfo, pbHash, cbHash, pbSignature, cbSignature, &cbSignature, BCRYPT_PAD_PKCS1); - CngThrowOnBadStatus("NCryptSignHash", status); - - auto result = std::vector(pbSignature, pbSignature + cbSignature); - HeapFree(GetProcessHeap(), 0, pbSignature); - return result; - } - catch (std::exception&) - { - HeapFree(GetProcessHeap(), 0, pbHash); - HeapFree(GetProcessHeap(), 0, pbSignature); - throw; - } - } - - bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) override - { - auto hKey = static_cast(key).GetKeyHandle(); - auto [cbHash, pbHash] = HashData(data, dataLen); - auto [cbSignature, pbSignature] = ToHeap(sig, sigLen); - - BCRYPT_PKCS1_PADDING_INFO paddingInfo { BCRYPT_SHA256_ALGORITHM }; - auto status = NCryptVerifySignature(hKey, &paddingInfo, pbHash, cbHash, pbSignature, cbSignature, BCRYPT_PAD_PKCS1); - HeapFree(GetProcessHeap(), 0, pbSignature); - return status == ERROR_SUCCESS; - } - -private: - static std::tuple HashData(const void * data, size_t dataLen) - { - auto hash = Hash::SHA256(data, dataLen); - return ToHeap(hash.data(), hash.size()); - } - - static std::tuple ToHeap(const void * data, size_t dataLen) - { - auto cbHash = (DWORD)dataLen; - auto pbHash = (PBYTE)HeapAlloc(GetProcessHeap(), 0, dataLen); - ThrowBadAllocOnNull(pbHash); - std::memcpy(pbHash, data, dataLen); - return std::make_tuple(cbHash, pbHash); - } -}; - -namespace Hash -{ - std::unique_ptr CreateSHA1() - { - return std::make_unique>(BCRYPT_SHA1_ALGORITHM); - } - - std::unique_ptr CreateSHA256() - { - return std::make_unique>(BCRYPT_SHA256_ALGORITHM); - } - - std::unique_ptr CreateRSA() - { - return std::make_unique(); - } - - std::unique_ptr CreateRSAKey() - { - return std::make_unique(); - } -} - -#endif diff --git a/src/openrct2/core/Crypt.OpenSSL.cpp b/src/openrct2/core/Crypt.OpenSSL.cpp index 4ab9c80162..31c9460f78 100644 --- a/src/openrct2/core/Crypt.OpenSSL.cpp +++ b/src/openrct2/core/Crypt.OpenSSL.cpp @@ -14,11 +14,7 @@ *****************************************************************************/ #pragma endregion -#if defined(_WIN32) && !defined(__USE_OPENSSL__) -#define __USE_CNG__ -#endif - -#ifndef __USE_CNG__ +#ifndef DISABLE_NETWORK #include "Crypt.h" #include @@ -27,6 +23,8 @@ #include #include +using namespace Crypt; + static void OpenSSLThrowOnBadStatus(const std::string_view& name, int status) { if (status != 1) @@ -35,6 +33,16 @@ static void OpenSSLThrowOnBadStatus(const std::string_view& name, int status) } } +static void OpenSSLInitialise() +{ + static bool _opensslInitialised = false; + if (!_opensslInitialised) + { + _opensslInitialised = true; + OpenSSL_add_all_algorithms(); + } +} + template class OpenSSLHashAlgorithm final : public TBase { @@ -117,6 +125,41 @@ public: EVP_PKEY_free(_evpKey); } + void Generate() override + { + auto ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr); + if (ctx == nullptr) + { + throw std::runtime_error("EVP_PKEY_CTX_new_id failed"); + } + + try + { + auto status = EVP_PKEY_CTX_set_rsa_keygen_bits(ctx, 2048); + if (status == 0) + { + throw std::runtime_error("EVP_PKEY_CTX_set_rsa_keygen_bits failed"); + } + + status = EVP_PKEY_keygen_init(ctx); + OpenSSLThrowOnBadStatus("EVP_PKEY_keygen_init", status); + + EVP_PKEY * key{}; + status = EVP_PKEY_keygen(ctx, &key); + OpenSSLThrowOnBadStatus("EVP_PKEY_keygen", status); + + EVP_PKEY_free(_evpKey); + _evpKey = key; + + EVP_PKEY_CTX_free(ctx); + } + catch (const std::exception&) + { + EVP_PKEY_CTX_free(ctx); + throw; + } + } + void SetPrivate(const std::string_view& pem) override { SetKey(pem, true); @@ -180,11 +223,6 @@ private: { throw std::runtime_error("EVP_PKEY_get1_RSA failed"); } - if (!RSA_check_key(rsa)) - { - RSA_free(rsa); - throw std::runtime_error("Loaded RSA key is invalid"); - } auto bio = BIO_new(BIO_s_mem()); if (bio == nullptr) @@ -286,27 +324,31 @@ public: } }; -namespace Hash +namespace Crypt { std::unique_ptr CreateSHA1() { + OpenSSLInitialise(); return std::make_unique>(EVP_sha1()); } std::unique_ptr CreateSHA256() { + OpenSSLInitialise(); return std::make_unique>(EVP_sha256()); } std::unique_ptr CreateRSA() { + OpenSSLInitialise(); return std::make_unique(); } std::unique_ptr CreateRSAKey() { + OpenSSLInitialise(); return std::make_unique(); } } -#endif +#endif // DISABLE_NETWORK diff --git a/src/openrct2/core/Crypt.h b/src/openrct2/core/Crypt.h index 029442fedf..178b431ae3 100644 --- a/src/openrct2/core/Crypt.h +++ b/src/openrct2/core/Crypt.h @@ -20,41 +20,43 @@ #include #include -template -class HashAlgorithm +namespace Crypt { -public: - typedef std::array Result; + template + class HashAlgorithm + { + public: + typedef std::array Result; - virtual ~HashAlgorithm() = default; - virtual HashAlgorithm * Clear() = 0; - virtual HashAlgorithm * Update(const void * data, size_t dataLen) = 0; - virtual Result Finish() = 0; -}; + virtual ~HashAlgorithm() = default; + virtual HashAlgorithm * Clear() = 0; + virtual HashAlgorithm * Update(const void * data, size_t dataLen) = 0; + virtual Result Finish() = 0; + }; -class RsaKey -{ -public: - virtual ~RsaKey() = default; - virtual void SetPrivate(const std::string_view& pem) = 0; - virtual void SetPublic(const std::string_view& pem) = 0; - virtual std::string GetPrivate() = 0; - virtual std::string GetPublic() = 0; -}; + class RsaKey + { + public: + virtual ~RsaKey() = default; + virtual void Generate() = 0; + virtual void SetPrivate(const std::string_view& pem) = 0; + virtual void SetPublic(const std::string_view& pem) = 0; + virtual std::string GetPrivate() = 0; + virtual std::string GetPublic() = 0; + }; -class RsaAlgorithm -{ -public: - virtual ~RsaAlgorithm() = default; - virtual std::vector SignData(const RsaKey& key, const void * data, size_t dataLen) = 0; - virtual bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) = 0; -}; + class RsaAlgorithm + { + public: + virtual ~RsaAlgorithm() = default; + virtual std::vector SignData(const RsaKey& key, const void * data, size_t dataLen) = 0; + virtual bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) = 0; + }; -using Sha1Algorithm = HashAlgorithm<20>; -using Sha256Algorithm = HashAlgorithm<32>; + using Sha1Algorithm = HashAlgorithm<20>; + using Sha256Algorithm = HashAlgorithm<32>; -namespace Hash -{ + // Factories std::unique_ptr CreateSHA1(); std::unique_ptr CreateSHA256(); std::unique_ptr CreateRSA(); diff --git a/src/openrct2/network/Network.cpp b/src/openrct2/network/Network.cpp index d4fe62525a..58f18dc0e0 100644 --- a/src/openrct2/network/Network.cpp +++ b/src/openrct2/network/Network.cpp @@ -75,8 +75,6 @@ static sint32 _pickup_peep_old_x = LOCATION_NULL; #include "NetworkAction.h" -#include // just for OpenSSL_add_all_algorithms() - #pragma comment(lib, "Ws2_32.lib") using namespace OpenRCT2; @@ -129,7 +127,6 @@ Network::Network() server_command_handlers[NETWORK_COMMAND_GAMEINFO] = &Network::Server_Handle_GAMEINFO; server_command_handlers[NETWORK_COMMAND_TOKEN] = &Network::Server_Handle_TOKEN; server_command_handlers[NETWORK_COMMAND_OBJECTS] = &Network::Server_Handle_OBJECTS; - OpenSSL_add_all_algorithms(); _chat_log_fs << std::unitbuf; _server_log_fs << std::unitbuf; @@ -668,11 +665,11 @@ bool Network::CheckSRAND(uint32 tick, uint32 srand0) server_srand0_tick = 0; // Check that the server and client sprite hashes match const char *client_sprite_hash = sprite_checksum(); - const bool sprites_mismatch = server_sprite_hash[0] != '\0' && strcmp(client_sprite_hash, server_sprite_hash) != 0; + const bool sprites_mismatch = server_sprite_hash[0] != '\0' && strcmp(client_sprite_hash, server_sprite_hash.c_str()) != 0; // Check PRNG values and sprite hashes, if exist if ((srand0 != server_srand0) || sprites_mismatch) { #ifdef DEBUG_DESYNC - dbg_report_desync(tick, srand0, server_srand0, client_sprite_hash, server_sprite_hash); + dbg_report_desync(tick, srand0, server_srand0, client_sprite_hash, server_sprite_hash.c_str()); #endif return false; } @@ -2370,13 +2367,15 @@ void Network::Client_Handle_TICK(NetworkConnection& connection, NetworkPacket& p if (server_srand0_tick == 0) { server_srand0 = srand0; server_srand0_tick = server_tick; - server_sprite_hash[0] = '\0'; + server_sprite_hash.resize(0); if (flags & NETWORK_TICK_FLAG_CHECKSUMS) { const char* text = packet.ReadString(); if (text != nullptr) { - safe_strcpy(server_sprite_hash, text, sizeof(server_sprite_hash)); + auto textLen = std::strlen(text); + server_sprite_hash.resize(textLen); + std::memcpy(server_sprite_hash.data(), text, textLen); } } } diff --git a/src/openrct2/network/NetworkKey.cpp b/src/openrct2/network/NetworkKey.cpp index 7ebe0a1a09..a6d4ca57da 100644 --- a/src/openrct2/network/NetworkKey.cpp +++ b/src/openrct2/network/NetworkKey.cpp @@ -17,78 +17,32 @@ #ifndef DISABLE_NETWORK #include -#include -#include - #include "../core/Crypt.h" #include "../core/IStream.hpp" #include "../Diagnostic.h" #include "NetworkKey.h" -#define KEY_TYPE EVP_PKEY_RSA - -constexpr sint32 KEY_LENGTH_BITS = 2048; - -NetworkKey::NetworkKey() -{ - _ctx = EVP_PKEY_CTX_new_id(KEY_TYPE, nullptr); - if (_ctx == nullptr) - { - log_error("Failed to create OpenSSL context"); - } -} - -NetworkKey::~NetworkKey() -{ - Unload(); - if (_ctx != nullptr) - { - EVP_PKEY_CTX_free(_ctx); - _ctx = nullptr; - } -} +NetworkKey::NetworkKey() { } +NetworkKey::~NetworkKey() { } void NetworkKey::Unload() { - if (_key != nullptr) - { - EVP_PKEY_free(_key); - _key = nullptr; - } + _key = nullptr; } bool NetworkKey::Generate() { - if (_ctx == nullptr) + try { - log_error("Invalid OpenSSL context"); + _key = Crypt::CreateRSAKey(); + _key->Generate(); + return true; + } + catch (const std::exception& e) + { + log_error("NetworkKey::Generate failed: %s", e.what()); return false; } -#if KEY_TYPE == EVP_PKEY_RSA - if (!EVP_PKEY_CTX_set_rsa_keygen_bits(_ctx, KEY_LENGTH_BITS)) - { - log_error("Failed to set keygen params"); - return false; - } -#else - #error Only RSA is supported! -#endif - if (EVP_PKEY_keygen_init(_ctx) <= 0) - { - log_error("Failed to initialise keygen algorithm"); - return false; - } - if (EVP_PKEY_keygen(_ctx, &_key) <= 0) - { - log_error("Failed to generate new key!"); - return false; - } - else - { - log_verbose("Key successfully generated"); - } - log_verbose("New key of type %d, length %d generated successfully.", KEY_TYPE, KEY_LENGTH_BITS); - return true; } bool NetworkKey::LoadPrivate(IStream * stream) @@ -106,34 +60,21 @@ bool NetworkKey::LoadPrivate(IStream * stream) log_error("Key file suspiciously large, refusing to load it"); return false; } - char * priv_key = new char[size]; - stream->Read(priv_key, size); - BIO * bio = BIO_new_mem_buf(priv_key, (sint32)size); - if (bio == nullptr) + + std::string pem(size, '\0'); + stream->Read(pem.data(), pem.size()); + + try { - log_error("Failed to initialise OpenSSL's BIO!"); - delete [] priv_key; + _key = Crypt::CreateRSAKey(); + _key->SetPrivate(pem); + return true; + } + catch (const std::exception& e) + { + log_error("NetworkKey::LoadPrivate failed: %s", e.what()); return false; } - RSA * rsa; - rsa = PEM_read_bio_RSAPrivateKey(bio, nullptr, nullptr, nullptr); - if (rsa == nullptr || !RSA_check_key(rsa)) - { - log_error("Loaded RSA key is invalid"); - BIO_free_all(bio); - delete [] priv_key; - return false; - } - if (_key != nullptr) - { - EVP_PKEY_free(_key); - } - _key = EVP_PKEY_new(); - EVP_PKEY_set1_RSA(_key, rsa); - BIO_free_all(bio); - RSA_free(rsa); - delete [] priv_key; - return true; } bool NetworkKey::LoadPublic(IStream * stream) @@ -151,152 +92,68 @@ bool NetworkKey::LoadPublic(IStream * stream) log_error("Key file suspiciously large, refusing to load it"); return false; } - char * pub_key = new char[size]; - stream->Read(pub_key, size); - BIO * bio = BIO_new_mem_buf(pub_key, (sint32)size); - if (bio == nullptr) + + std::string pem(size, '\0'); + stream->Read(pem.data(), pem.size()); + + try { - log_error("Failed to initialise OpenSSL's BIO!"); - delete [] pub_key; + _key = Crypt::CreateRSAKey(); + _key->SetPublic(pem); + return true; + } + catch (const std::exception& e) + { + log_error("NetworkKey::LoadPublic failed: %s", e.what()); return false; } - RSA * rsa; - rsa = PEM_read_bio_RSAPublicKey(bio, nullptr, nullptr, nullptr); - if (_key != nullptr) - { - EVP_PKEY_free(_key); - } - _key = EVP_PKEY_new(); - EVP_PKEY_set1_RSA(_key, rsa); - BIO_free_all(bio); - RSA_free(rsa); - delete [] pub_key; - return true; } bool NetworkKey::SavePrivate(IStream * stream) { - if (_key == nullptr) + try { - log_error("No key loaded"); + if (_key == nullptr) + { + throw std::exception("No key loaded"); + } + auto pem = _key->GetPrivate(); + stream->Write(pem.data(), pem.size()); + return true; + } + catch (const std::exception& e) + { + log_error("NetworkKey::SavePrivate failed: %s", e.what()); return false; } -#if KEY_TYPE == EVP_PKEY_RSA - RSA * rsa = EVP_PKEY_get1_RSA(_key); - if (rsa == nullptr) - { - log_error("Failed to get RSA key handle!"); - return false; - } - if (!RSA_check_key(rsa)) - { - log_error("Loaded RSA key is invalid"); - return false; - } - BIO * bio = BIO_new(BIO_s_mem()); - if (bio == nullptr) - { - log_error("Failed to initialise OpenSSL's BIO!"); - return false; - } - sint32 result = PEM_write_bio_RSAPrivateKey(bio, rsa, nullptr, nullptr, 0, nullptr, nullptr); - if (result != 1) - { - log_error("failed to write private key!"); - BIO_free_all(bio); - return false; - } - RSA_free(rsa); - - sint32 keylen = BIO_pending(bio); - char * pem_key = new char[keylen]; - BIO_read(bio, pem_key, keylen); - stream->Write(pem_key, keylen); - log_verbose("saving key of length %u", keylen); - BIO_free_all(bio); - delete [] pem_key; -#else - #error Only RSA is supported! -#endif - - return true; } bool NetworkKey::SavePublic(IStream * stream) { - if (_key == nullptr) + try { - log_error("No key loaded"); + if (_key == nullptr) + { + throw std::exception("No key loaded"); + } + auto pem = _key->GetPrivate(); + stream->Write(pem.data(), pem.size()); + return true; + } + catch (const std::exception& e) + { + log_error("NetworkKey::SavePublic failed: %s", e.what()); return false; } - RSA * rsa = EVP_PKEY_get1_RSA(_key); - if (rsa == nullptr) - { - log_error("Failed to get RSA key handle!"); - return false; - } - BIO * bio = BIO_new(BIO_s_mem()); - if (bio == nullptr) - { - log_error("Failed to initialise OpenSSL's BIO!"); - return false; - } - sint32 result = PEM_write_bio_RSAPublicKey(bio, rsa); - if (result != 1) - { - log_error("failed to write private key!"); - BIO_free_all(bio); - return false; - } - RSA_free(rsa); - - sint32 keylen = BIO_pending(bio); - char * pem_key = new char[keylen]; - BIO_read(bio, pem_key, keylen); - stream->Write(pem_key, keylen); - BIO_free_all(bio); - delete [] pem_key; - - return true; } std::string NetworkKey::PublicKeyString() { if (_key == nullptr) { - log_error("No key loaded"); - return nullptr; + throw std::exception("No key loaded"); } - RSA * rsa = EVP_PKEY_get1_RSA(_key); - if (rsa == nullptr) - { - log_error("Failed to get RSA key handle!"); - return nullptr; - } - BIO * bio = BIO_new(BIO_s_mem()); - if (bio == nullptr) - { - log_error("Failed to initialise OpenSSL's BIO!"); - return nullptr; - } - sint32 result = PEM_write_bio_RSAPublicKey(bio, rsa); - if (result != 1) - { - log_error("failed to write private key!"); - BIO_free_all(bio); - return nullptr; - } - RSA_free(rsa); - - sint32 keylen = BIO_pending(bio); - char * pem_key = new char[keylen + 1]; - BIO_read(bio, pem_key, keylen); - BIO_free_all(bio); - pem_key[keylen] = '\0'; - std::string pem_key_out(pem_key); - delete [] pem_key; - - return pem_key_out; + return _key->GetPublic(); } /** @@ -319,7 +176,7 @@ std::string NetworkKey::PublicKeyHash() { throw std::runtime_error("No key found"); } - auto hash = Hash::SHA1(key.c_str(), key.size()); + auto hash = Crypt::SHA1(key.c_str(), key.size()); std::string result; result.reserve(hash.size() * 2); @@ -331,7 +188,7 @@ std::string NetworkKey::PublicKeyHash() } return result; } - catch (std::exception& e) + catch (const std::exception& e) { log_error("Failed to create hash of public key: %s", e.what()); } @@ -340,100 +197,34 @@ std::string NetworkKey::PublicKeyHash() bool NetworkKey::Sign(const uint8 * md, const size_t len, char ** signature, size_t * out_size) { - EVP_MD_CTX * mdctx = nullptr; - - *signature = nullptr; - - /* Create the Message Digest Context */ - if ((mdctx = EVP_MD_CTX_create()) == nullptr) + try { - log_error("Failed to create MD context"); - return false; + auto rsa = Crypt::CreateRSA(); + auto sig = rsa->SignData(*_key, md, len); + *out_size = sig.size(); + *signature = new char[sig.size()]; + std::memcpy(*signature, sig.data(), sig.size()); + return true; } - /* Initialise the DigestSign operation - SHA-256 has been selected as the message digest function in this example */ - if (1 != EVP_DigestSignInit(mdctx, nullptr, EVP_sha256(), nullptr, _key)) + catch (const std::exception& e) { - log_error("Failed to init digest sign"); - EVP_MD_CTX_destroy(mdctx); + log_error("NetworkKey::Sign failed: %s", e.what()); + *signature = nullptr; + *out_size = 0; return false; } - /* Call update with the message */ - if (1 != EVP_DigestSignUpdate(mdctx, md, len)) - { - log_error("Failed to goto update digest"); - EVP_MD_CTX_destroy(mdctx); - return false; - } - - /* Finalise the DigestSign operation */ - /* First call EVP_DigestSignFinal with a nullptr sig parameter to obtain the length of the - * signature. Length is returned in slen */ - if (1 != EVP_DigestSignFinal(mdctx, nullptr, out_size)) - { - log_error("failed to finalise signature"); - EVP_MD_CTX_destroy(mdctx); - return false; - } - - uint8 * sig; - /* Allocate memory for the signature based on size in slen */ - if ((sig = (unsigned char*)malloc((sint32)(sizeof(unsigned char) * (*out_size)))) == nullptr) - { - log_error("Failed to crypto-allocate space for signature"); - EVP_MD_CTX_destroy(mdctx); - return false; - } - /* Obtain the signature */ - if (1 != EVP_DigestSignFinal(mdctx, sig, out_size)) { - log_error("Failed to finalise signature"); - EVP_MD_CTX_destroy(mdctx); - free(sig); - return false; - } - *signature = new char[*out_size]; - memcpy(*signature, sig, *out_size); - free(sig); - EVP_MD_CTX_destroy(mdctx); - - return true; } bool NetworkKey::Verify(const uint8 * md, const size_t len, const char * sig, const size_t siglen) { - EVP_MD_CTX * mdctx = nullptr; - - /* Create the Message Digest Context */ - if ((mdctx = EVP_MD_CTX_create()) == nullptr) + try { - log_error("Failed to create MD context"); - return false; + auto rsa = Crypt::CreateRSA(); + return rsa->VerifyData(*_key, md, len, sig, siglen); } - - if (1 != EVP_DigestVerifyInit(mdctx, nullptr, EVP_sha256(), nullptr, _key)) + catch (const std::exception& e) { - log_error("Failed to initialise verification routine"); - EVP_MD_CTX_destroy(mdctx); - return false; - } - - /* Initialize `key` with a public key */ - if (1 != EVP_DigestVerifyUpdate(mdctx, md, len)) - { - log_error("Failed to update verification"); - EVP_MD_CTX_destroy(mdctx); - return false; - } - - if (1 == EVP_DigestVerifyFinal(mdctx, (uint8 *)sig, siglen)) - { - EVP_MD_CTX_destroy(mdctx); - log_verbose("Successfully verified signature"); - return true; - } - else - { - EVP_MD_CTX_destroy(mdctx); - log_error("Signature is invalid"); + log_error("NetworkKey::Verify failed: %s", e.what()); return false; } } diff --git a/src/openrct2/network/NetworkKey.h b/src/openrct2/network/NetworkKey.h index eb5e9275f4..4bc274d29f 100644 --- a/src/openrct2/network/NetworkKey.h +++ b/src/openrct2/network/NetworkKey.h @@ -20,14 +20,16 @@ #ifndef DISABLE_NETWORK #include "../common.h" +#include #include -#include - -using EVP_PKEY = evp_pkey_st; -using EVP_PKEY_CTX = evp_pkey_ctx_st; interface IStream; +namespace Crypt +{ + class RsaKey; +} + class NetworkKey final { public: @@ -45,8 +47,7 @@ public: bool Verify(const uint8 * md, const size_t len, const char * sig, const size_t siglen); private: NetworkKey (const NetworkKey &) = delete; - EVP_PKEY_CTX * _ctx = nullptr; - EVP_PKEY * _key = nullptr; + std::unique_ptr _key; }; #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/network.h b/src/openrct2/network/network.h index be0a20aca3..438a15607d 100644 --- a/src/openrct2/network/network.h +++ b/src/openrct2/network/network.h @@ -64,7 +64,6 @@ namespace OpenRCT2 #include #include #include -#include #include "../actions/GameAction.h" #include "../core/Json.hpp" #include "../core/Nullable.hpp" @@ -242,7 +241,7 @@ private: uint32 server_tick = 0; uint32 server_srand0 = 0; uint32 server_srand0_tick = 0; - char server_sprite_hash[EVP_MAX_MD_SIZE + 1]{}; + std::string server_sprite_hash; uint8 player_id = 0; std::list> client_connection_list; std::multiset game_command_queue; diff --git a/src/openrct2/world/Sprite.cpp b/src/openrct2/world/Sprite.cpp index dc10708cf4..d8e655d1f7 100644 --- a/src/openrct2/world/Sprite.cpp +++ b/src/openrct2/world/Sprite.cpp @@ -211,6 +211,8 @@ static size_t GetSpatialIndexOffset(sint32 x, sint32 y) const char * sprite_checksum() { + using namespace Crypt; + // TODO Remove statics, should be one of these per sprite manager / OpenRCT2 context. // Alternatively, make a new class for this functionality. static std::unique_ptr> _spriteHashAlg; @@ -220,7 +222,7 @@ const char * sprite_checksum() { if (_spriteHashAlg == nullptr) { - _spriteHashAlg = Hash::CreateSHA1(); + _spriteHashAlg = CreateSHA1(); } _spriteHashAlg->Clear(); diff --git a/test/tests/CryptTests.cpp b/test/tests/CryptTests.cpp index f5b3318adb..b609873d13 100644 --- a/test/tests/CryptTests.cpp +++ b/test/tests/CryptTests.cpp @@ -34,7 +34,7 @@ public: TEST_F(CryptTests, SHA1_Basic) { std::string input = "The quick brown fox jumped over the lazy dog."; - auto result = Hash::SHA1(input.data(), input.size()); + auto result = Crypt::SHA1(input.data(), input.size()); AssertHash("c0854fb9fb03c41cce3802cb0d220529e6eef94e", result); } @@ -46,7 +46,7 @@ TEST_F(CryptTests, SHA1_Multiple) "This balloon from Balloon Stall 1 is really good value" }; - auto alg = Hash::CreateSHA1(); + auto alg = Crypt::CreateSHA1(); for (auto s : input) { alg->Update(s.data(), s.size()); @@ -61,7 +61,7 @@ TEST_F(CryptTests, SHA1_WithClear) std::string inputA = "Merry-go-round 2 looks too intense for me"; std::string inputB = "This park is really clean and tidy"; - auto alg = Hash::CreateSHA1(); + auto alg = Crypt::CreateSHA1(); alg->Update(inputA.data(), inputA.size()); alg->Clear(); alg->Update(inputB.data(), inputB.size()); @@ -70,7 +70,7 @@ TEST_F(CryptTests, SHA1_WithClear) TEST_F(CryptTests, SHA1_Many) { - auto alg = Hash::CreateSHA1(); + auto alg = Crypt::CreateSHA1(); // First digest std::string inputA[] = { @@ -104,10 +104,10 @@ TEST_F(CryptTests, RSA_Basic) std::vector data = { 0, 1, 2, 3, 4, 5, 6, 7 }; auto file = File::ReadAllBytes("C:/Users/Ted/Documents/OpenRCT2/keys/Ted.privkey"); - auto key = Hash::CreateRSAKey(); + auto key = Crypt::CreateRSAKey(); key->SetPrivate(std::string_view((const char *)file.data(), file.size())); - auto rsa = Hash::CreateRSA(); + auto rsa = Crypt::CreateRSA(); auto signature = rsa->SignData(*key, data.data(), data.size()); bool verified = rsa->VerifyData(*key, data.data(), data.size(), signature.data(), signature.size()); ASSERT_TRUE(verified); @@ -118,14 +118,14 @@ TEST_F(CryptTests, RSA_VerifyWithPublic) std::vector data = { 7, 6, 5, 4, 3, 2, 1, 0 }; auto privateFile = File::ReadAllBytes("C:/Users/Ted/Documents/OpenRCT2/keys/Ted.privkey"); - auto privateKey = Hash::CreateRSAKey(); + auto privateKey = Crypt::CreateRSAKey(); privateKey->SetPrivate(std::string_view((const char *)privateFile.data(), privateFile.size())); auto publicFile = File::ReadAllBytes("C:/Users/Ted/Documents/OpenRCT2/keys/Ted-b298a310905df8865788bdc864560c3d4c3ba562.pubkey"); - auto publicKey = Hash::CreateRSAKey(); + auto publicKey = Crypt::CreateRSAKey(); publicKey->SetPublic(std::string_view((const char *)publicFile.data(), publicFile.size())); - auto rsa = Hash::CreateRSA(); + auto rsa = Crypt::CreateRSA(); auto signature = rsa->SignData(*privateKey, data.data(), data.size()); bool verified = rsa->VerifyData(*publicKey, data.data(), data.size(), signature.data(), signature.size()); ASSERT_TRUE(verified); @@ -134,8 +134,30 @@ TEST_F(CryptTests, RSA_VerifyWithPublic) TEST_F(CryptTests, RSAKey_GetPublic) { auto inPem = File::ReadAllText("C:/Users/Ted/Documents/OpenRCT2/keys/Ted-b298a310905df8865788bdc864560c3d4c3ba562.pubkey"); - auto publicKey = Hash::CreateRSAKey(); + auto publicKey = Crypt::CreateRSAKey(); publicKey->SetPublic(inPem); auto outPem = publicKey->GetPublic(); ASSERT_EQ(inPem, outPem); } + +TEST_F(CryptTests, RSAKey_Generate) +{ + auto key = Crypt::CreateRSAKey(); + + // Test generate twice, first checking if the PEMs contain expected strings + key->Generate(); + auto privatePem1 = key->GetPrivate(); + auto publicPem1 = key->GetPublic(); + ASSERT_NE(privatePem1.find("RSA PRIVATE KEY"), std::string::npos); + ASSERT_NE(publicPem1.find("RSA PUBLIC KEY"), std::string::npos); + + key->Generate(); + auto privatePem2 = key->GetPrivate(); + auto publicPem2 = key->GetPublic(); + ASSERT_NE(privatePem2.find("RSA PRIVATE KEY"), std::string::npos); + ASSERT_NE(publicPem2.find("RSA PUBLIC KEY"), std::string::npos); + + // Now check that generate gives a different key each time + ASSERT_STRNE(privatePem1.c_str(), privatePem2.c_str()); + ASSERT_STRNE(publicPem1.c_str(), publicPem2.c_str()); +}