From de0e6bf521ea9124ea8102892a4827da18f9958b Mon Sep 17 00:00:00 2001 From: Ted John Date: Wed, 8 Feb 2017 12:53:00 +0000 Subject: [PATCH] Use IStream for network code --- src/openrct2/network/NetworkKey.cpp | 39 ++++--- src/openrct2/network/NetworkKey.h | 11 +- src/openrct2/network/TcpSocket.cpp | 48 +++++++- src/openrct2/network/TcpSocket.h | 3 + src/openrct2/network/network.cpp | 163 +++++++++++++++------------- 5 files changed, 166 insertions(+), 98 deletions(-) diff --git a/src/openrct2/network/NetworkKey.cpp b/src/openrct2/network/NetworkKey.cpp index 51dffa1e66..d2e4c8e843 100644 --- a/src/openrct2/network/NetworkKey.cpp +++ b/src/openrct2/network/NetworkKey.cpp @@ -16,13 +16,14 @@ #ifndef DISABLE_NETWORK -#include "NetworkKey.h" -#include "../diagnostic.h" - -#include -#include -#include #include +#include +#include +#include + +#include "../core/IStream.hpp" +#include "../diagnostic.h" +#include "NetworkKey.h" #define KEY_TYPE EVP_PKEY_RSA @@ -90,10 +91,11 @@ bool NetworkKey::Generate() return true; } -bool NetworkKey::LoadPrivate(SDL_RWops * file) +bool NetworkKey::LoadPrivate(IStream * stream) { - assert(file != nullptr); - size_t size = (size_t)file->size(file); + Guard::ArgumentNotNull(stream); + + size_t size = (size_t)stream->GetLength(); if (size == (size_t)-1) { log_error("unknown size, refusing to load key"); @@ -105,7 +107,7 @@ bool NetworkKey::LoadPrivate(SDL_RWops * file) return false; } char * priv_key = new char[size]; - file->read(file, priv_key, 1, size); + stream->Read(priv_key, size); BIO * bio = BIO_new_mem_buf(priv_key, (sint32)size); if (bio == nullptr) { @@ -134,10 +136,11 @@ bool NetworkKey::LoadPrivate(SDL_RWops * file) return true; } -bool NetworkKey::LoadPublic(SDL_RWops * file) +bool NetworkKey::LoadPublic(IStream * stream) { - assert(file != nullptr); - size_t size = (size_t)file->size(file); + Guard::ArgumentNotNull(stream); + + size_t size = (size_t)stream->GetLength(); if (size == (size_t)-1) { log_error("unknown size, refusing to load key"); @@ -149,7 +152,7 @@ bool NetworkKey::LoadPublic(SDL_RWops * file) return false; } char * pub_key = new char[size]; - file->read(file, pub_key, 1, size); + stream->Read(pub_key, size); BIO * bio = BIO_new_mem_buf(pub_key, (sint32)size); if (bio == nullptr) { @@ -171,7 +174,7 @@ bool NetworkKey::LoadPublic(SDL_RWops * file) return true; } -bool NetworkKey::SavePrivate(SDL_RWops *file) +bool NetworkKey::SavePrivate(IStream * stream) { if (_key == nullptr) { @@ -208,7 +211,7 @@ bool NetworkKey::SavePrivate(SDL_RWops *file) sint32 keylen = BIO_pending(bio); char * pem_key = new char[keylen]; BIO_read(bio, pem_key, keylen); - file->write(file, pem_key, keylen, 1); + stream->Write(pem_key, keylen); log_verbose("saving key of length %u", keylen); BIO_free_all(bio); delete [] pem_key; @@ -219,7 +222,7 @@ bool NetworkKey::SavePrivate(SDL_RWops *file) return true; } -bool NetworkKey::SavePublic(SDL_RWops *file) +bool NetworkKey::SavePublic(IStream * stream) { if (_key == nullptr) { @@ -250,7 +253,7 @@ bool NetworkKey::SavePublic(SDL_RWops *file) sint32 keylen = BIO_pending(bio); char * pem_key = new char[keylen]; BIO_read(bio, pem_key, keylen); - file->write(file, pem_key, keylen, 1); + stream->Write(pem_key, keylen); BIO_free_all(bio); delete [] pem_key; diff --git a/src/openrct2/network/NetworkKey.h b/src/openrct2/network/NetworkKey.h index 1ccd305778..3955ee789b 100644 --- a/src/openrct2/network/NetworkKey.h +++ b/src/openrct2/network/NetworkKey.h @@ -21,22 +21,23 @@ #include "../common.h" -#include #include typedef struct evp_pkey_st EVP_PKEY; typedef struct evp_pkey_ctx_st EVP_PKEY_CTX; +interface IStream; + class NetworkKey final { public: NetworkKey(); ~NetworkKey(); bool Generate(); - bool LoadPrivate(SDL_RWops * file); - bool LoadPublic(SDL_RWops * file); - bool SavePrivate(SDL_RWops * file); - bool SavePublic(SDL_RWops * file); + bool LoadPrivate(IStream * stream); + bool LoadPublic(IStream * stream); + bool SavePrivate(IStream * stream); + bool SavePublic(IStream * stream); std::string PublicKeyString(); std::string PublicKeyHash(); void Unload(); diff --git a/src/openrct2/network/TcpSocket.cpp b/src/openrct2/network/TcpSocket.cpp index bdc58f5500..d593f96fec 100644 --- a/src/openrct2/network/TcpSocket.cpp +++ b/src/openrct2/network/TcpSocket.cpp @@ -66,6 +66,10 @@ constexpr uint32 CONNECT_TIMEOUT_MS = 3000; +#ifdef __WINDOWS__ + static bool _wsaInitialised = false; +#endif + class TcpSocket; class SocketException : public Exception @@ -431,7 +435,7 @@ private: explicit TcpSocket(SOCKET socket) { _socket = socket; - _status = SOCKET_STATUS_CONNECTED; + _status = SOCKET_STATUS_CONNECTED; } void CloseSocket() @@ -492,4 +496,46 @@ ITcpSocket * CreateTcpSocket() return new TcpSocket(); } +bool InitialiseWSA() +{ +#ifdef __WINDOWS__ + if (!_wsaInitialised) + { + log_verbose("Initialising WSA"); + WSADATA wsa_data; + if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) + { + log_error("Unable to initialise winsock."); + return false; + } + _wsaInitialised = true; + } + return _wsaInitialised; +#endif +} + +void DisposeWSA() +{ +#ifdef __WINDOWS__ + if (_wsaInitialised) + { + WSACleanup(); + _wsaInitialised = false; + } +#endif +} + +namespace Convert +{ + uint16 HostToNetwork(uint16 value) + { + return htons(value); + } + + uint16 NetworkToHost(uint16 value) + { + return ntohs(value); + } +} + #endif diff --git a/src/openrct2/network/TcpSocket.h b/src/openrct2/network/TcpSocket.h index 1e7f9e0320..c9cf7379ce 100644 --- a/src/openrct2/network/TcpSocket.h +++ b/src/openrct2/network/TcpSocket.h @@ -62,3 +62,6 @@ public: }; ITcpSocket * CreateTcpSocket(); + +bool InitialiseWSA(); +void DisposeWSA(); diff --git a/src/openrct2/network/network.cpp b/src/openrct2/network/network.cpp index 3dcf61a75a..3ff79f2c21 100644 --- a/src/openrct2/network/network.cpp +++ b/src/openrct2/network/network.cpp @@ -16,19 +16,12 @@ #include -#ifdef __WINDOWS__ - // winsock2 must be included before windows.h - #include -#else - #include -#endif - #include "../core/Guard.hpp" +#include "../OpenRCT2.h" extern "C" { -#include "../OpenRCT2.h" -#include "../platform/platform.h" -#include "../util/sawyercoding.h" + #include "../platform/platform.h" + #include "../util/sawyercoding.h" } #include "network.h" @@ -45,8 +38,10 @@ sint32 _pickup_peep_old_x = SPRITE_LOCATION_NULL; #include #include "../core/Console.hpp" +#include "../core/FileStream.hpp" #include "../core/Json.hpp" #include "../core/Math.hpp" +#include "../core/MemoryStream.h" #include "../core/Path.hpp" #include "../core/String.hpp" #include "../core/Util.hpp" @@ -130,17 +125,9 @@ Network::~Network() bool Network::Init() { -#ifdef __WINDOWS__ - if (!wsa_initialized) { - log_verbose("Initialising WSA"); - WSADATA wsa_data; - if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) { - log_error("Unable to initialise winsock."); - return false; - } - wsa_initialized = true; + if (!InitialiseWSA()) { + return false; } -#endif status = NETWORK_STATUS_READY; @@ -183,12 +170,7 @@ void Network::Close() player_list.clear(); group_list.clear(); -#ifdef __WINDOWS__ - if (wsa_initialized) { - WSACleanup(); - wsa_initialized = false; - } -#endif + DisposeWSA(); CloseChatLog(); gfx_invalidate_screen(); @@ -229,36 +211,47 @@ bool Network::BeginClient(const char* host, uint16 port) return false; } - SDL_RWops *privkey = SDL_RWFromFile(keyPath, "wb+"); - if (privkey == nullptr) { + try + { + auto fs = FileStream(keyPath, FILE_MODE_WRITE); + _key.SavePrivate(&fs); + } + catch (Exception) + { log_error("Unable to save private key at %s.", keyPath); return false; } - _key.SavePrivate(privkey); - SDL_RWclose(privkey); const std::string hash = _key.PublicKeyHash(); const utf8 *publicKeyHash = hash.c_str(); network_get_public_key_path(keyPath, sizeof(keyPath), gConfigNetwork.player_name, publicKeyHash); Console::WriteLine("Key generated, saving public bits as %s", keyPath); - SDL_RWops *pubkey = SDL_RWFromFile(keyPath, "wb+"); - if (pubkey == nullptr) { + + try + { + auto fs = FileStream(keyPath, FILE_MODE_WRITE); + _key.SavePublic(&fs); + } + catch (Exception) + { log_error("Unable to save public key at %s.", keyPath); return false; } - _key.SavePublic(pubkey); - SDL_RWclose(pubkey); } else { - log_verbose("Loading key from %s", keyPath); - SDL_RWops *privkey = SDL_RWFromFile(keyPath, "rb"); - if (privkey == nullptr) { + // LoadPrivate returns validity of loaded key + bool ok = false; + try + { + log_verbose("Loading key from %s", keyPath); + auto fs = FileStream(keyPath, FILE_MODE_OPEN); + ok = _key.LoadPrivate(&fs); + } + catch (Exception) + { log_error("Unable to read private key from %s.", keyPath); return false; } - // LoadPrivate returns validity of loaded key - bool ok = _key.LoadPrivate(privkey); - SDL_RWclose(privkey); // Don't store private key in memory when it's not in use. _key.Unload(); return ok; @@ -1381,15 +1374,23 @@ void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& log_error("Key file (%s) was not found. Restart client to re-generate it.", keyPath); return; } - SDL_RWops *privkey = SDL_RWFromFile(keyPath, "rb"); - bool ok = _key.LoadPrivate(privkey); - SDL_RWclose(privkey); - if (!ok) { + + try + { + auto fs = FileStream(keyPath, FILE_MODE_OPEN); + if (!_key.LoadPrivate(&fs)) + { + throw Exception(); + } + } + catch (Exception) + { log_error("Failed to load key %s", keyPath); connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); connection.Socket->Disconnect(); return; } + uint32 challenge_size; packet >> challenge_size; const char *challenge = (const char *)packet.Read(challenge_size); @@ -1398,7 +1399,7 @@ void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& const std::string pubkey = _key.PublicKeyString(); _challenge.resize(challenge_size); memcpy(_challenge.data(), challenge, challenge_size); - ok = _key.Sign(_challenge.data(), _challenge.size(), &signature, &sigsize); + bool ok = _key.Sign(_challenge.data(), _challenge.size(), &signature, &sigsize); if (!ok) { log_error("Failed to sign server's challenge."); connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); @@ -1552,26 +1553,45 @@ void Network::Server_Handle_AUTH(NetworkConnection& connection, NetworkPacket& p if (pubkey == nullptr) { connection.AuthStatus = NETWORK_AUTH_VERIFICATIONFAILURE; } else { - const char *signature = (const char *)packet.Read(sigsize); - SDL_RWops *pubkey_rw = SDL_RWFromConstMem(pubkey, (sint32)strlen(pubkey)); - if (signature == nullptr || pubkey_rw == nullptr) { - connection.AuthStatus = NETWORK_AUTH_VERIFICATIONFAILURE; - log_verbose("Signature verification failed, invalid data!"); - } else { - connection.Key.LoadPublic(pubkey_rw); - SDL_RWclose(pubkey_rw); + try + { + const char *signature = (const char *)packet.Read(sigsize); + if (signature == nullptr) + { + throw Exception(); + } + + auto ms = MemoryStream(pubkey, strlen(pubkey)); + if (!connection.Key.LoadPublic(&ms)) + { + throw Exception(); + } + bool verified = connection.Key.Verify(connection.Challenge.data(), connection.Challenge.size(), signature, sigsize); const std::string hash = connection.Key.PublicKeyHash(); - if (verified) { - connection.AuthStatus = NETWORK_AUTH_VERIFIED; + if (verified) + { log_verbose("Signature verification ok. Hash %s", hash.c_str()); - } else { + if (gConfigNetwork.known_keys_only && _userManager.GetUserByHash(hash) == nullptr) + { + log_verbose("Hash %s, not known", hash.c_str()); + connection.AuthStatus = NETWORK_AUTH_UNKNOWN_KEY_DISALLOWED; + } + else + { + connection.AuthStatus = NETWORK_AUTH_VERIFIED; + } + } + else + { connection.AuthStatus = NETWORK_AUTH_VERIFICATIONFAILURE; log_verbose("Signature verification failed!"); } - if (gConfigNetwork.known_keys_only && _userManager.GetUserByHash(hash) == nullptr) { - connection.AuthStatus = NETWORK_AUTH_UNKNOWN_KEY_DISALLOWED; - } + } + catch (Exception) + { + connection.AuthStatus = NETWORK_AUTH_VERIFICATIONFAILURE; + log_verbose("Signature verification failed, invalid data!"); } } @@ -1955,19 +1975,6 @@ void Network::Client_Handle_GAMEINFO(NetworkConnection& connection, NetworkPacke network_chat_show_server_greeting(); } -namespace Convert -{ - uint16 HostToNetwork(uint16 value) - { - return htons(value); - } - - uint16 NetworkToHost(uint16 value) - { - return ntohs(value); - } -} - sint32 network_init() { return gNetwork.Init(); @@ -2456,8 +2463,16 @@ void network_send_password(const char* password) log_error("Private key %s missing! Restart the game to generate it.", keyPath); return; } - SDL_RWops *privkey = SDL_RWFromFile(keyPath, "rb"); - gNetwork._key.LoadPrivate(privkey); + try + { + auto fs = FileStream(keyPath, FILE_MODE_OPEN); + gNetwork._key.LoadPrivate(&fs); + } + catch (Exception) + { + log_error("Error reading private key from %s.", keyPath); + return; + } const std::string pubkey = gNetwork._key.PublicKeyString(); size_t sigsize; char *signature;