/*****************************************************************************
 * Copyright (c) 2014-2020 OpenRCT2 developers
 *
 * For a complete list of all authors, please refer to contributors.md
 * Interested in contributing? Visit https://github.com/OpenRCT2/OpenRCT2
 *
 * OpenRCT2 is licensed under the GNU General Public License version 3.
 *****************************************************************************/

#if !defined(DISABLE_NETWORK) && defined(_WIN32) && (!defined(_WIN32_WINNT) || _WIN32_WINNT >= 0x0600)

#    include "../platform/Platform2.h"
#    include "Crypt.h"
#    include "IStream.hpp"

#    include <algorithm>
#    include <cstdint>
#    include <cstring>
#    include <iomanip>
#    include <istream>
#    include <limits>
#    include <memory>
#    include <stdexcept>
#    include <string>
#    include <string_view>
#    include <tuple>
#    include <vector>

// clang-format off
// CNG: Cryptography API: Next Generation (CNG)
//      available in Windows Vista onwards.
#include <windows.h>
#include <wincrypt.h>
#include <bcrypt.h>
constexpr bool NT_SUCCESS(NTSTATUS status) {return status >= 0;}
// clang-format on

using namespace Crypt;

// Throws std::runtime_error naming the failed CNG call when status is a failure code.
static void CngThrowOnBadStatus(std::string_view name, NTSTATUS status)
{
    if (!NT_SUCCESS(status))
    {
        throw std::runtime_error(std::string(name) + " failed: " + std::to_string(status));
    }
}

// Hash implementation backed by a CNG algorithm provider.
// TBase is the abstract hash interface being implemented; it supplies the Result type.
template<typename TBase> class CngHashAlgorithm final : public TBase
{
private:
    const wchar_t* _algName{};
    BCRYPT_ALG_HANDLE _hAlg{};
    BCRYPT_HASH_HANDLE _hHash{};
    // Working memory handed to BCryptCreateHash, must outlive _hHash
    std::vector<uint8_t> _hashObject;
    bool _reusable{};

public:
    explicit CngHashAlgorithm(const wchar_t* algName)
        : _algName(algName)
        // BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8
        , _reusable(Platform::IsOSVersionAtLeast(6, 2, 0))
    {
        try
        {
            Initialise();
        }
        catch (...)
        {
            // The destructor does not run for a throwing constructor, so release what was acquired
            Dispose();
            throw;
        }
    }

    // The object owns raw CNG handles, a copy would release them twice
    CngHashAlgorithm(const CngHashAlgorithm&) = delete;
    CngHashAlgorithm& operator=(const CngHashAlgorithm&) = delete;

    ~CngHashAlgorithm()
    {
        Dispose();
    }

    // Resets the hash so that a new digest can be started.
    TBase* Clear() override
    {
        if (_reusable)
        {
            // Finishing the current digest clears the state ready for a new digest
            Finish();
        }
        else
        {
            Dispose();
            Initialise();
        }
        return this;
    }

    // Feeds dataLen bytes into the current digest. Throws std::runtime_error on failure.
    TBase* Update(const void* data, size_t dataLen) override
    {
        auto status = BCryptHashData(
            _hHash, reinterpret_cast<PBYTE>(const_cast<void*>(data)), static_cast<ULONG>(dataLen), 0);
        CngThrowOnBadStatus("BCryptHashData", status);
        return this;
    }

    // Completes the digest and returns it. Throws std::runtime_error on failure.
    typename TBase::Result Finish() override
    {
        typename TBase::Result result;
        auto status = BCryptFinishHash(_hHash, result.data(), static_cast<ULONG>(result.size()), 0);
        CngThrowOnBadStatus("BCryptFinishHash", status);
        return result;
    }

private:
    // Opens the provider and creates the hash handle.
    void Initialise()
    {
        const ULONG flags = _reusable ? BCRYPT_HASH_REUSABLE_FLAG : 0;
        auto status = BCryptOpenAlgorithmProvider(&_hAlg, _algName, nullptr, flags);
        CngThrowOnBadStatus("BCryptOpenAlgorithmProvider", status);

        // Calculate the size of the buffer to hold the hash object
        DWORD cbHashObject{};
        DWORD cbData{};
        status = BCryptGetProperty(
            _hAlg, BCRYPT_OBJECT_LENGTH, reinterpret_cast<PBYTE>(&cbHashObject), sizeof(DWORD), &cbData, 0);
        CngThrowOnBadStatus("BCryptGetProperty", status);

        // Create a hash
        _hashObject.assign(cbHashObject, 0);
        status = BCryptCreateHash(_hAlg, &_hHash, _hashObject.data(), cbHashObject, nullptr, 0, 0);
        CngThrowOnBadStatus("BCryptCreateHash", status);
    }

    // Releases everything Initialise acquired. Safe to call on a partially initialised object.
    void Dispose()
    {
        if (_hHash != nullptr)
        {
            BCryptDestroyHash(_hHash);
            _hHash = {};
        }
        if (_hAlg != nullptr)
        {
            BCryptCloseAlgorithmProvider(_hAlg, 0);
            _hAlg = {};
        }
        // Only release the working memory once the hash handle is gone
        _hashObject.clear();
    }
};

// Minimal DER reader: supports a SEQUENCE header followed by INTEGER values.
class DerReader
{
private:
    ivstream<uint8_t> _stream;

    // A short read would otherwise leave zero-filled data that looks like valid key material
    static void ThrowOnShortRead(const std::istream& stream)
    {
        if (stream.fail())
        {
            throw std::runtime_error("Unexpected end of DER data");
        }
    }

    template<typename T> T Read(std::istream& stream)
    {
        T value{};
        stream.read(reinterpret_cast<char*>(&value), sizeof(T));
        ThrowOnShortRead(stream);
        return value;
    }

    template<typename T> std::vector<T> Read(std::istream& stream, size_t count)
    {
        std::vector<T> values(count);
        if (count != 0)
        {
            stream.read(reinterpret_cast<char*>(values.data()), static_cast<std::streamsize>(sizeof(T) * count));
            ThrowOnShortRead(stream);
        }
        return values;
    }

    // Returns the low-form tag number. High-form tags (0x1F) are rejected.
    int ReadTag(std::istream& stream)
    {
        auto a = Read<uint8_t>(stream);
        auto tagNumber = a & 0x1F;
        if (tagNumber == 0x1F)
        {
            throw std::runtime_error("Unsupported DER tag");
        }
        return tagNumber;
    }

    // Reads a DER length in either short form or long form (up to 6 length bytes).
    size_t ReadLength(std::istream& stream)
    {
        const auto a = Read<uint8_t>(stream);
        const auto len = a & 0x7F;
        if (len == a)
        {
            // High bit clear: the byte is the length itself
            return static_cast<size_t>(len);
        }
        if (len > 6)
        {
            throw std::runtime_error("Length over 48 bits not supported at this position");
        }
        if (len == 0)
        {
            throw std::runtime_error("Unknown length");
        }

        // Accumulate in 64 bits so that six length bytes can not overflow
        uint64_t result = 0;
        for (auto i = 0; i < len; i++)
        {
            result = (result << 8) | Read<uint8_t>(stream);
        }
        if (result > (std::numeric_limits<size_t>::max)())
        {
            throw std::runtime_error("Length too large");
        }
        return static_cast<size_t>(result);
    }

public:
    // data must outlive the reader.
    explicit DerReader(const std::vector<uint8_t>& data)
        : _stream(data)
    {
    }

    // Consumes a SEQUENCE header whose length is in long form with one or two length bytes.
    void ReadSequenceHeader()
    {
        // Tag and first length byte are compared as one 16-bit value in host byte order
        auto a = Read<uint16_t>(_stream);
        if (a == 0x8130)
        {
            Read<uint8_t>(_stream);
        }
        else if (a == 0x8230)
        {
            Read<uint16_t>(_stream);
        }
        else
        {
            throw std::runtime_error("Invalid DER code");
        }
    }

    // Reads an INTEGER and returns its bytes with leading sign padding removed.
    std::vector<uint8_t> ReadInteger()
    {
        auto t = ReadTag(_stream);
        if (t != 2)
        {
            throw std::runtime_error("Expected INTEGER");
        }
        const auto len = ReadLength(_stream);
        auto result = Read<uint8_t>(_stream, len);
        if (result.empty())
        {
            return result;
        }

        const uint8_t pad = (result[0] > 127) ? 255 : 0;
        const auto firstSignificant = std::find_if(
            result.begin(), result.end(), [pad](uint8_t b) { return b != pad; });
        // When every byte is padding the value is left untouched
        if (firstSignificant != result.end())
        {
            result.erase(result.begin(), firstSignificant);
        }
        return result;
    }
};

// Minimal DER writer: collects INTEGER values and wraps them in a SEQUENCE.
class DerWriter
{
private:
    std::vector<uint8_t> _buffer;

public:
    void WriteSequenceHeader(size_t len)
    {
        _buffer.push_back(0x30);
        WriteCompressedNumber(len);
    }

    void WriteInteger(const std::vector<uint8_t>& data)
    {
        _buffer.push_back(0x02);
        size_t dataLen = data.size();
        if (dataLen > 0 && data[0] > 127)
        {
            // Prepend a zero to number so it isn't treated as negative
            WriteCompressedNumber(dataLen + 1);
            _buffer.push_back(0);
        }
        else
        {
            WriteCompressedNumber(dataLen);
        }
        _buffer.insert(_buffer.end(), data.begin(), data.end());
    }

    // Writes a DER length. Throws std::runtime_error for lengths that need more than two bytes.
    void WriteCompressedNumber(uint64_t value)
    {
        if (value < 128)
        {
            _buffer.push_back(static_cast<uint8_t>(value));
        }
        else if (value <= (std::numeric_limits<uint8_t>::max)())
        {
            _buffer.push_back(0b10000001);
            _buffer.push_back(static_cast<uint8_t>(value));
        }
        else if (value <= (std::numeric_limits<uint16_t>::max)())
        {
            _buffer.push_back(0b10000010);
            _buffer.push_back(static_cast<uint8_t>((value >> 8) & 0xFF));
            _buffer.push_back(static_cast<uint8_t>(value & 0xFF));
        }
        else
        {
            // Emitting nothing here would silently produce a corrupt document
            throw std::runtime_error("DER length too large");
        }
    }

    // Wraps everything written so far in a SEQUENCE, returns it and leaves the writer empty.
    std::vector<uint8_t> Complete()
    {
        std::vector<uint8_t> body;
        body.swap(_buffer);
        WriteSequenceHeader(body.size());
        _buffer.insert(_buffer.end(), body.begin(), body.end());

        std::vector<uint8_t> result;
        result.swap(_buffer);
        return result;
    }
};

// RSA key stored in a CNG key handle, imported from / exported to PKCS#1 PEM.
class CngRsaKey final : public RsaKey
{
private:
    // Key components in the big-endian byte form used by the CNG key blobs.
    struct RsaKeyParams
    {
        std::vector<uint8_t> Modulus;
        std::vector<uint8_t> Exponent;
        std::vector<uint8_t> Prime1;
        std::vector<uint8_t> Prime2;
        std::vector<uint8_t> Exponent1;
        std::vector<uint8_t> Exponent2;
        std::vector<uint8_t> Coefficient;
        std::vector<uint8_t> PrivateExponent;

        // Picks the blob magic from which components are present.
        ULONG GetMagic() const
        {
            ULONG magic = BCRYPT_RSAPUBLIC_MAGIC;
            if (!Prime1.empty() || !Prime2.empty())
                magic = BCRYPT_RSAPRIVATE_MAGIC;
            if (!Exponent1.empty())
                magic = BCRYPT_RSAFULLPRIVATE_MAGIC;
            return magic;
        }

        size_t GetTotalSize() const
        {
            return Modulus.size() + Exponent.size() + Prime1.size() + Prime2.size() + Exponent1.size() + Exponent2.size()
                + Coefficient.size() + PrivateExponent.size();
        }

        // Parses a BCRYPT_RSAKEY_BLOB. Throws std::runtime_error if the blob is truncated.
        static RsaKeyParams FromBlob(const std::vector<uint8_t>& blob)
        {
            if (blob.size() < sizeof(BCRYPT_RSAKEY_BLOB))
            {
                throw std::runtime_error("RSA key blob is too small");
            }

            // Copy the header out rather than aliasing the byte buffer
            BCRYPT_RSAKEY_BLOB header{};
            std::memcpy(&header, blob.data(), sizeof(header));

            RsaKeyParams result;
            size_t offset = sizeof(BCRYPT_RSAKEY_BLOB);
            result.Exponent = ReadBytes(blob, offset, header.cbPublicExp);
            result.Modulus = ReadBytes(blob, offset, header.cbModulus);
            if (header.Magic == BCRYPT_RSAPRIVATE_MAGIC || header.Magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
            {
                result.Prime1 = ReadBytes(blob, offset, header.cbPrime1);
                result.Prime2 = ReadBytes(blob, offset, header.cbPrime2);
            }
            if (header.Magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
            {
                result.Exponent1 = ReadBytes(blob, offset, header.cbPrime1);
                result.Exponent2 = ReadBytes(blob, offset, header.cbPrime2);
                result.Coefficient = ReadBytes(blob, offset, header.cbPrime1);
                result.PrivateExponent = ReadBytes(blob, offset, header.cbModulus);
            }
            return result;
        }

        // Serialises the components as a BCRYPT_RSAKEY_BLOB matching GetMagic().
        std::vector<uint8_t> ToBlob() const
        {
            const auto magic = GetMagic();

            BCRYPT_RSAKEY_BLOB header{};
            header.Magic = magic;
            header.BitLength = static_cast<ULONG>(Modulus.size() * 8);
            header.cbPublicExp = static_cast<ULONG>(Exponent.size());
            header.cbModulus = static_cast<ULONG>(Modulus.size());
            header.cbPrime1 = static_cast<ULONG>(Prime1.size());
            header.cbPrime2 = static_cast<ULONG>(Prime2.size());

            const auto* headerBytes = reinterpret_cast<const uint8_t*>(&header);
            std::vector<uint8_t> blob(headerBytes, headerBytes + sizeof(header));

            WriteBytes(blob, Exponent);
            WriteBytes(blob, Modulus);
            if (magic == BCRYPT_RSAPRIVATE_MAGIC || magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
            {
                WriteBytes(blob, Prime1);
                WriteBytes(blob, Prime2);
            }
            if (magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
            {
                // The full private blob uses fixed field widths (cbPrime1, cbPrime2, cbPrime1, cbModulus),
                // the same widths FromBlob reads. Values that lost leading zero bytes when parsed from
                // DER are padded back to those widths.
                WriteBytes(blob, Exponent1, Prime1.size());
                WriteBytes(blob, Exponent2, Prime2.size());
                WriteBytes(blob, Coefficient, Prime1.size());
                WriteBytes(blob, PrivateExponent, Modulus.size());
            }
            return blob;
        }

    private:
        // Copies length bytes starting at offset and advances offset past them.
        static std::vector<uint8_t> ReadBytes(const std::vector<uint8_t>& src, size_t& offset, size_t length)
        {
            if (offset > src.size() || length > src.size() - offset)
            {
                throw std::runtime_error("RSA key blob is truncated");
            }
            std::vector<uint8_t> result(src.begin() + offset, src.begin() + offset + length);
            offset += length;
            return result;
        }

        // Appends src to dst, left-padded with zero bytes up to minLength.
        static void WriteBytes(std::vector<uint8_t>& dst, const std::vector<uint8_t>& src, size_t minLength = 0)
        {
            if (src.size() < minLength)
            {
                dst.insert(dst.end(), minLength - src.size(), uint8_t{ 0 });
            }
            dst.insert(dst.end(), src.begin(), src.end());
        }
    };

public:
    BCRYPT_KEY_HANDLE GetKeyHandle() const
    {
        return _hKey;
    }

    CngRsaKey()
    {
        auto status = BCryptOpenAlgorithmProvider(&_hAlg, BCRYPT_RSA_ALGORITHM, nullptr, 0);
        CngThrowOnBadStatus("BCryptOpenAlgorithmProvider", status);
    }

    // The object owns raw CNG handles, a copy would release them twice
    CngRsaKey(const CngRsaKey&) = delete;
    CngRsaKey& operator=(const CngRsaKey&) = delete;

    ~CngRsaKey()
    {
        Reset();
        if (_hAlg != nullptr)
        {
            BCryptCloseAlgorithmProvider(_hAlg, 0);
            _hAlg = {};
        }
    }

    // Imports a private key from "RSA PRIVATE KEY" PEM text. Throws std::runtime_error on bad input.
    void SetPrivate(std::string_view pem) override
    {
        auto der = ReadPEM(pem, SZ_PRIVATE_BEGIN_TOKEN, SZ_PRIVATE_END_TOKEN);
        DerReader derReader(der);
        RsaKeyParams params;
        derReader.ReadSequenceHeader();
        // First integer is read and discarded (GetPrivate writes a single zero byte here)
        derReader.ReadInteger();
        params.Modulus = derReader.ReadInteger();
        params.Exponent = derReader.ReadInteger();
        params.PrivateExponent = derReader.ReadInteger();
        params.Prime1 = derReader.ReadInteger();
        params.Prime2 = derReader.ReadInteger();
        params.Exponent1 = derReader.ReadInteger();
        params.Exponent2 = derReader.ReadInteger();
        params.Coefficient = derReader.ReadInteger();
        ImportKey(params);
    }

    // Imports a public key from "RSA PUBLIC KEY" PEM text. Throws std::runtime_error on bad input.
    void SetPublic(std::string_view pem) override
    {
        auto der = ReadPEM(pem, SZ_PUBLIC_BEGIN_TOKEN, SZ_PUBLIC_END_TOKEN);
        DerReader derReader(der);
        RsaKeyParams params;
        derReader.ReadSequenceHeader();
        params.Modulus = derReader.ReadInteger();
        params.Exponent = derReader.ReadInteger();
        ImportKey(params);
    }

    // Exports the key as "RSA PRIVATE KEY" PEM text.
    std::string GetPrivate() override
    {
        auto params = ExportKey();
        DerWriter derWriter;
        derWriter.WriteInteger({ 0 });
        derWriter.WriteInteger(params.Modulus);
        derWriter.WriteInteger(params.Exponent);
        derWriter.WriteInteger(params.PrivateExponent);
        derWriter.WriteInteger(params.Prime1);
        derWriter.WriteInteger(params.Prime2);
        derWriter.WriteInteger(params.Exponent1);
        derWriter.WriteInteger(params.Exponent2);
        derWriter.WriteInteger(params.Coefficient);
        return WritePEM(derWriter.Complete(), SZ_PRIVATE_BEGIN_TOKEN, SZ_PRIVATE_END_TOKEN);
    }

    // Exports the modulus and public exponent as "RSA PUBLIC KEY" PEM text.
    std::string GetPublic() override
    {
        auto params = ExportKey();
        DerWriter derWriter;
        derWriter.WriteInteger(params.Modulus);
        derWriter.WriteInteger(params.Exponent);
        return WritePEM(derWriter.Complete(), SZ_PUBLIC_BEGIN_TOKEN, SZ_PUBLIC_END_TOKEN);
    }

    // Replaces the current key with a freshly generated 1024-bit key pair.
    void Generate() override
    {
        Reset();
        try
        {
            auto status = BCryptGenerateKeyPair(_hAlg, &_hKey, 1024, 0);
            CngThrowOnBadStatus("BCryptGenerateKeyPair", status);
            status = BCryptFinalizeKeyPair(_hKey, 0);
            CngThrowOnBadStatus("BCryptFinalizeKeyPair", status);
            _keyBlobType = BCRYPT_RSAFULLPRIVATE_BLOB;
        }
        catch (const std::exception&)
        {
            Reset();
            throw;
        }
    }

private:
    static constexpr const char* SZ_PUBLIC_BEGIN_TOKEN = "-----BEGIN RSA PUBLIC KEY-----";
    static constexpr const char* SZ_PUBLIC_END_TOKEN = "-----END RSA PUBLIC KEY-----";
    static constexpr const char* SZ_PRIVATE_BEGIN_TOKEN = "-----BEGIN RSA PRIVATE KEY-----";
    static constexpr const char* SZ_PRIVATE_END_TOKEN = "-----END RSA PRIVATE KEY-----";

    BCRYPT_KEY_HANDLE _hKey{};
    BCRYPT_ALG_HANDLE _hAlg{};
    LPCWSTR _keyBlobType{};

    // Destroys the current key handle, if any.
    void Reset()
    {
        if (_hKey != nullptr)
        {
            BCryptDestroyKey(_hKey);
            _hKey = {};
        }
    }

    void ImportKey(const RsaKeyParams& params)
    {
        Reset();
        auto blob = params.ToBlob();
        _keyBlobType = params.GetMagic() == BCRYPT_RSAFULLPRIVATE_MAGIC ? BCRYPT_RSAFULLPRIVATE_BLOB
                                                                        : BCRYPT_RSAPUBLIC_BLOB;
        auto status = BCryptImportKeyPair(
            _hAlg, nullptr, _keyBlobType, &_hKey, blob.data(), static_cast<ULONG>(blob.size()), 0);
        CngThrowOnBadStatus("BCryptImportKeyPair", status);
    }

    RsaKeyParams ExportKey()
    {
        // First call queries the required size, second call fills the buffer
        ULONG cbOutput{};
        auto status = BCryptExportKey(_hKey, nullptr, _keyBlobType, nullptr, 0, &cbOutput, 0);
        CngThrowOnBadStatus("BCryptExportKey", status);

        std::vector<uint8_t> blob(cbOutput);
        status = BCryptExportKey(_hKey, nullptr, _keyBlobType, blob.data(), cbOutput, &cbOutput, 0);
        CngThrowOnBadStatus("BCryptExportKey", status);
        blob.resize(cbOutput);

        return RsaKeyParams::FromBlob(blob);
    }

    // Extracts and decodes the base64 payload between beginToken and endToken.
    // Throws std::runtime_error("Invalid PEM file") if the tokens are missing, out of order
    // or enclose nothing.
    static std::vector<uint8_t> ReadPEM(std::string_view pem, std::string_view beginToken, std::string_view endToken)
    {
        const auto beginPos = pem.find(beginToken);
        if (beginPos != std::string_view::npos)
        {
            const auto payloadPos = beginPos + beginToken.size();
            // Only accept an end token that follows the begin token
            const auto endPos = pem.find(endToken, payloadPos);
            if (endPos != std::string_view::npos)
            {
                const auto code = Trim(pem.substr(payloadPos, endPos - payloadPos));
                if (!code.empty())
                {
                    return DecodeBase64(code);
                }
            }
        }
        throw std::runtime_error("Invalid PEM file");
    }

    // Strips leading and trailing characters below '!' (spaces and control characters).
    static std::string_view Trim(std::string_view input)
    {
        size_t begin = 0;
        while (begin < input.size() && input[begin] < '!')
        {
            begin++;
        }
        size_t end = input.size();
        while (end > begin && input[end - 1] < '!')
        {
            end--;
        }
        return input.substr(begin, end - begin);
    }

    // Wraps DER bytes in PEM armour using the given tokens.
    static std::string WritePEM(const std::vector<uint8_t>& derBytes, const char* beginToken, const char* endToken)
    {
        std::string pem(beginToken);
        pem += '\n';
        pem += EncodeBase64(derBytes);
        pem += endToken;
        pem += '\n';
        return pem;
    }

    static std::string EncodeBase64(const std::vector<uint8_t>& input)
    {
        DWORD flags = CRYPT_STRING_BASE64 | CRYPT_STRING_NOCR;
        DWORD chString{};
        if (!CryptBinaryToStringA(input.data(), static_cast<DWORD>(input.size()), flags, nullptr, &chString))
        {
            throw std::runtime_error("CryptBinaryToStringA failed");
        }
        std::string result(chString, 0);
        if (!CryptBinaryToStringA(input.data(), static_cast<DWORD>(input.size()), flags, result.data(), &chString))
        {
            throw std::runtime_error("CryptBinaryToStringA failed");
        }

        // CryptBinaryToStringA returns length that includes null terminator
        if (!result.empty())
        {
            result.pop_back();
        }
        return result;
    }

    static std::vector<uint8_t> DecodeBase64(std::string_view input)
    {
        DWORD cbBinary{};
        if (!CryptStringToBinaryA(
                input.data(), static_cast<DWORD>(input.size()), CRYPT_STRING_BASE64, nullptr, &cbBinary, nullptr, nullptr))
        {
            throw std::runtime_error("CryptStringToBinaryA failed");
        }
        std::vector<uint8_t> result(cbBinary);
        if (!CryptStringToBinaryA(
                input.data(), static_cast<DWORD>(input.size()), CRYPT_STRING_BASE64, result.data(), &cbBinary, nullptr,
                nullptr))
        {
            throw std::runtime_error("CryptStringToBinaryA failed");
        }
        result.resize(cbBinary);
        return result;
    }
};

// RSA signing / verification with PKCS#1 padding over a SHA-256 digest of the data.
class CngRsaAlgorithm final : public RsaAlgorithm
{
public:
    // Returns the signature of data. Throws std::runtime_error if signing fails.
    // key must be a CngRsaKey.
    std::vector<uint8_t> SignData(const RsaKey& key, const void* data, size_t dataLen) override
    {
        auto hKey = static_cast<const CngRsaKey&>(key).GetKeyHandle();
        auto hash = Crypt::SHA256(data, dataLen);
        const auto cbHash = static_cast<ULONG>(hash.size());
        BCRYPT_PKCS1_PADDING_INFO paddingInfo{ BCRYPT_SHA256_ALGORITHM };

        // First call queries the signature size, second call produces the signature
        ULONG cbSignature{};
        auto status = BCryptSignHash(hKey, &paddingInfo, hash.data(), cbHash, nullptr, 0, &cbSignature, BCRYPT_PAD_PKCS1);
        CngThrowOnBadStatus("BCryptSignHash", status);

        std::vector<uint8_t> signature(cbSignature);
        status = BCryptSignHash(
            hKey, &paddingInfo, hash.data(), cbHash, signature.data(), cbSignature, &cbSignature, BCRYPT_PAD_PKCS1);
        CngThrowOnBadStatus("BCryptSignHash", status);

        signature.resize(cbSignature);
        return signature;
    }

    // Returns true if sig is a valid signature of data for key. key must be a CngRsaKey.
    bool VerifyData(const RsaKey& key, const void* data, size_t dataLen, const void* sig, size_t sigLen) override
    {
        auto hKey = static_cast<const CngRsaKey&>(key).GetKeyHandle();
        auto hash = Crypt::SHA256(data, dataLen);

        // BCryptVerifySignature takes a non-const buffer, so work on a copy of the signature
        const auto* sigBytes = static_cast<const uint8_t*>(sig);
        std::vector<uint8_t> signature(sigBytes, sigBytes + sigLen);

        BCRYPT_PKCS1_PADDING_INFO paddingInfo{ BCRYPT_SHA256_ALGORITHM };
        auto status = BCryptVerifySignature(
            hKey, &paddingInfo, hash.data(), static_cast<ULONG>(hash.size()), signature.data(),
            static_cast<ULONG>(signature.size()), BCRYPT_PAD_PKCS1);
        return status == ERROR_SUCCESS;
    }
};

namespace Crypt
{
    std::unique_ptr<Sha1Algorithm> CreateSHA1()
    {
        return std::make_unique<CngHashAlgorithm<Sha1Algorithm>>(BCRYPT_SHA1_ALGORITHM);
    }

    std::unique_ptr<Sha256Algorithm> CreateSHA256()
    {
        return std::make_unique<CngHashAlgorithm<Sha256Algorithm>>(BCRYPT_SHA256_ALGORITHM);
    }

    std::unique_ptr<RsaAlgorithm> CreateRSA()
    {
        return std::make_unique<CngRsaAlgorithm>();
    }

    std::unique_ptr<RsaKey> CreateRSAKey()
    {
        return std::make_unique<CngRsaKey>();
    }
} // namespace Crypt

#endif