From fe7e8a17de67ebf761caee631fd232ee84524610 Mon Sep 17 00:00:00 2001 From: Ted John Date: Fri, 25 May 2018 00:39:27 +0100 Subject: [PATCH] Start implementing RSA for OpenSSL --- src/openrct2/core/Crypt.CNG.cpp | 25 +++--- src/openrct2/core/Crypt.OpenSSL.cpp | 134 +++++++++++++++++++++++++--- src/openrct2/core/Crypt.h | 35 +++++++- 3 files changed, 167 insertions(+), 27 deletions(-) diff --git a/src/openrct2/core/Crypt.CNG.cpp b/src/openrct2/core/Crypt.CNG.cpp index 64495a33ad..385089b788 100644 --- a/src/openrct2/core/Crypt.CNG.cpp +++ b/src/openrct2/core/Crypt.CNG.cpp @@ -32,28 +32,31 @@ #include #define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0) -class CNGSha1Algorithm final : public Sha1Algorithm +template +class CngHashAlgorithm final : public TBase { private: + const char * _algName; BCRYPT_ALG_HANDLE _hAlg{}; BCRYPT_HASH_HANDLE _hHash{}; PBYTE _pbHashObject{}; bool _reusable{}; public: - CNGSha1Algorithm() + CngHashAlgorithm(const char * algName) { // BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8 + _algName = algName; _reusable = Platform::IsOSVersionAtLeast(6, 2, 0); Initialise(); } - ~CNGSha1Algorithm() + ~CngHashAlgorithm() { Dispose(); } - void Clear() override + HashAlgorithm * Clear() override { if (_reusable) { @@ -65,15 +68,17 @@ public: Dispose(); Initialise(); } + return this; } - void Update(const void * data, size_t dataLen) override + HashAlgorithm * Update(const void * data, size_t dataLen) override { auto status = BCryptHashData(_hHash, (PBYTE)data, (ULONG)dataLen, 0); if (!NT_SUCCESS(status)) { throw std::runtime_error("BCryptHashData failed: " + std::to_string(status)); } + return this; } Result Finish() override @@ -91,7 +96,7 @@ private: void Initialise() { auto flags = _reusable ? BCRYPT_HASH_REUSABLE_FLAG : 0; - auto status = BCryptOpenAlgorithmProvider(&_hAlg, BCRYPT_SHA1_ALGORITHM, nullptr, flags); + auto status = BCryptOpenAlgorithmProvider(&_hAlg, TAlg, nullptr, flags); if (!NT_SUCCESS(status)) { throw std::runtime_error("BCryptOpenAlgorithmProvider failed: " + std::to_string(status)); @@ -135,14 +140,12 @@ namespace Hash { std::unique_ptr CreateSHA1() { - return std::make_unique(); + return std::make_unique>(BCRYPT_SHA1_ALGORITHM); } - Sha1Algorithm::Result SHA1(const void * data, size_t dataLen) + std::unique_ptr CreateSHA256() { - CNGSha1Algorithm sha1; - sha1.Update(data, dataLen); - return sha1.Finish(); + return std::make_unique>(BCRYPT_SHA256_ALGORITHM); } } diff --git a/src/openrct2/core/Crypt.OpenSSL.cpp b/src/openrct2/core/Crypt.OpenSSL.cpp index 19ecb33d83..5df154bb9c 100644 --- a/src/openrct2/core/Crypt.OpenSSL.cpp +++ b/src/openrct2/core/Crypt.OpenSSL.cpp @@ -23,17 +23,29 @@ #include "Crypt.h" #include #include +#include #include -class OpenSSLSha1Algorithm final : public Sha1Algorithm +static void OpenSSLThrowOnBadStatus(const std::string_view& name, int status) +{ + if (status != 1) + { + throw std::runtime_error(std::string(name) + " failed: " + std::to_string(status)); + } +} + +template +class OpenSSLHashAlgorithm final : public TBase { private: + const EVP_MD * _type; EVP_MD_CTX * _ctx{}; bool _initialised{}; public: - OpenSSLSha1Algorithm() + OpenSSLHashAlgorithm(const EVP_MD * type) { + _type = type; _ctx = EVP_MD_CTX_create(); if (_ctx == nullptr) { @@ -41,21 +53,22 @@ public: } } - ~OpenSSLSha1Algorithm() + ~OpenSSLHashAlgorithm() { EVP_MD_CTX_destroy(_ctx); } - void Clear() override + TBase * Clear() override { - if (EVP_DigestInit_ex(_ctx, EVP_sha1(), nullptr) <= 0) + if (EVP_DigestInit_ex(_ctx, _type, nullptr) <= 0) { throw std::runtime_error("EVP_DigestInit_ex failed"); } _initialised = true; + return this; } - void Update(const void * data, size_t dataLen) override + TBase * Update(const void * data, size_t dataLen) override { // Auto initialise if (!_initialised) @@ -67,9 +80,10 @@ public: { throw std::runtime_error("EVP_DigestUpdate failed"); } + return this; } - Result Finish() override + typename TBase::Result Finish() override { if (!_initialised) { @@ -77,7 +91,7 @@ public: } _initialised = false; - Result result; + typename TBase::Result result; unsigned int digestSize{}; if (EVP_DigestFinal(_ctx, result.data(), &digestSize) <= 0) { @@ -92,18 +106,112 @@ public: } }; +class OpenSSLRsaKey final : public RsaKey +{ +public: + EVP_PKEY * const EvpKey{}; + + void SetPrivate(const std::string_view& pem) override { } + void SetPublic(const std::string_view& pem) override { } + std::string GetPrivate() override { return ""; } + std::string GetPublic() override { return ""; } +}; + +class OpenSSLRsaAlgorithm final : public RsaAlgorithm +{ +public: + std::vector SignData(const RsaKey& key, const void * data, size_t dataLen) override + { + auto evpKey = static_cast(key).EvpKey; + EVP_MD_CTX * mdctx{}; + try + { + mdctx = EVP_MD_CTX_create(); + if (mdctx == nullptr) + { + throw std::runtime_error("EVP_MD_CTX_create failed"); + } + + auto status = EVP_DigestSignInit(mdctx, nullptr, EVP_sha256(), nullptr, evpKey); + OpenSSLThrowOnBadStatus("EVP_DigestSignInit failed", status); + + status = EVP_DigestSignUpdate(mdctx, data, dataLen); + OpenSSLThrowOnBadStatus("EVP_DigestSignUpdate failed", status); + + // Get required length of signature + size_t sigLen{}; + status = EVP_DigestSignFinal(mdctx, nullptr, &sigLen); + OpenSSLThrowOnBadStatus("EVP_DigestSignFinal failed", status); + + // Get signature + std::vector signature(sigLen); + status = EVP_DigestSignFinal(mdctx, signature.data(), &sigLen); + OpenSSLThrowOnBadStatus("EVP_DigestSignFinal failed", status); + + EVP_MD_CTX_destroy(mdctx); + return signature; + } + catch (const std::exception&) + { + EVP_MD_CTX_destroy(mdctx); + throw; + } + } + + bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) override + { + auto evpKey = static_cast(key).EvpKey; + EVP_MD_CTX * mdctx{}; + try + { + mdctx = EVP_MD_CTX_create(); + if (mdctx == nullptr) + { + throw std::runtime_error("EVP_MD_CTX_create failed"); + } + + auto status = EVP_DigestVerifyInit(mdctx, nullptr, EVP_sha256(), nullptr, evpKey); + OpenSSLThrowOnBadStatus("EVP_DigestVerifyInit", status); + + status = EVP_DigestVerifyUpdate(mdctx, data, dataLen); + OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status); + + status = EVP_DigestVerifyFinal(mdctx, (uint8_t*)sig, sigLen); + if (status != 0 && status != 1) + { + OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status); + } + EVP_MD_CTX_destroy(mdctx); + return status == 0; + } + catch (const std::exception&) + { + EVP_MD_CTX_destroy(mdctx); + throw; + } + } +}; + namespace Hash { std::unique_ptr CreateSHA1() { - return std::make_unique(); + return std::make_unique>(EVP_sha1()); } - Sha1Algorithm::Result SHA1(const void * data, size_t dataLen) + std::unique_ptr CreateSHA256() { - OpenSSLSha1Algorithm sha1; - sha1.Update(data, dataLen); - return sha1.Finish(); + return std::make_unique>(EVP_sha256()); + } + + std::unique_ptr CreateRSA() + { + return std::make_unique(); + } + + std::unique_ptr CreateRSAKey() + { + return std::make_unique(); } } diff --git a/src/openrct2/core/Crypt.h b/src/openrct2/core/Crypt.h index 04c5ea7e41..387898c92e 100644 --- a/src/openrct2/core/Crypt.h +++ b/src/openrct2/core/Crypt.h @@ -18,6 +18,7 @@ #include #include +#include template class HashAlgorithm @@ -26,15 +27,43 @@ public: typedef std::array Result; virtual ~HashAlgorithm() = default; - virtual void Clear() = 0; - virtual void Update(const void * data, size_t dataLen) = 0; + virtual HashAlgorithm * Clear() = 0; + virtual HashAlgorithm * Update(const void * data, size_t dataLen) = 0; virtual Result Finish() = 0; }; +class RsaKey +{ +public: + virtual ~RsaKey() = default; + virtual void SetPrivate(const std::string_view& pem) = 0; + virtual void SetPublic(const std::string_view& pem) = 0; + virtual std::string GetPrivate() = 0; + virtual std::string GetPublic() = 0; +}; + +class RsaAlgorithm +{ +public: + virtual ~RsaAlgorithm() = default; + virtual std::vector SignData(const RsaKey& key, const void * data, size_t dataLen) = 0; + virtual bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) = 0; +}; + using Sha1Algorithm = HashAlgorithm<20>; +using Sha256Algorithm = HashAlgorithm<32>; namespace Hash { std::unique_ptr CreateSHA1(); - Sha1Algorithm::Result SHA1(const void * data, size_t dataLen); + std::unique_ptr CreateSHA256(); + std::unique_ptr CreateRSA(); + std::unique_ptr CreateRSAKey(); + + Sha1Algorithm::Result SHA1(const void * data, size_t dataLen) + { + return CreateSHA1() + ->Update(data, dataLen) + ->Finish(); + } }