From 46ecd53a991e038eda2cfb177a2870e69eca7ce3 Mon Sep 17 00:00:00 2001 From: Ted John Date: Sat, 28 May 2016 16:45:09 +0100 Subject: [PATCH] refactor NetworkAddress --- openrct2.vcxproj | 7 ++ src/network/NetworkAddress.cpp | 123 +++++++++++++++++++++++++++++++++ src/network/NetworkAddress.h | 71 +++++++++++++++++++ src/network/NetworkTypes.h | 27 ++++++++ src/network/network.cpp | 96 +++++-------------------- src/network/network.h | 28 +------- 6 files changed, 248 insertions(+), 104 deletions(-) create mode 100644 src/network/NetworkAddress.cpp create mode 100644 src/network/NetworkAddress.h create mode 100644 src/network/NetworkTypes.h diff --git a/openrct2.vcxproj b/openrct2.vcxproj index f22ed97908..4d4d928b7e 100644 --- a/openrct2.vcxproj +++ b/openrct2.vcxproj @@ -83,6 +83,7 @@ + @@ -358,6 +359,8 @@ + + @@ -480,6 +483,7 @@ 4013 false true + true true @@ -487,6 +491,9 @@ UseFastLinkTimeCodeGeneration /OPT:NOLBR /ignore:4099 %(AdditionalOptions) + + true + diff --git a/src/network/NetworkAddress.cpp b/src/network/NetworkAddress.cpp new file mode 100644 index 0000000000..66b763d1ac --- /dev/null +++ b/src/network/NetworkAddress.cpp @@ -0,0 +1,123 @@ +#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers +/***************************************************************************** + * OpenRCT2, an open source clone of Roller Coaster Tycoon 2. + * + * OpenRCT2 is the work of many authors, a full list can be found in contributors.md + * For more information, visit https://github.com/OpenRCT2/OpenRCT2 + * + * OpenRCT2 is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * A full copy of the GNU General Public License can be found in licence.txt + *****************************************************************************/ +#pragma endregion + +#include +#include +#include "NetworkAddress.h" + +NetworkAddress::NetworkAddress() +{ + _result = std::make_shared(); + _result->status = RESOLVE_NONE; + _resolveMutex = SDL_CreateMutex(); +} + +NetworkAddress::~NetworkAddress() +{ + SDL_DestroyMutex(_resolveMutex); +} + +void NetworkAddress::Resolve(const char * host, uint16 port) +{ + SDL_LockMutex(_resolveMutex); + { + // Create a new result store + _result = std::make_shared(); + _result->status = RESOLVE_INPROGRESS; + + // Create a new request + auto req = new ResolveRequest(); + req->Host = std::string(host == nullptr ? "" : host); + req->Port = port;; + req->Result = _result; + + // Resolve synchronously + ResolveWorker(req); + } + SDL_UnlockMutex(_resolveMutex); +} + +void NetworkAddress::ResolveAsync(const char * host, uint16 port) +{ + SDL_LockMutex(_resolveMutex); + { + // Create a new result store + _result = std::make_shared(); + _result->status = RESOLVE_INPROGRESS; + + // Create a new request + auto req = new ResolveRequest(); + req->Host = std::string(host); + req->Port = port; + req->Result = _result; + + // Spin off a worker thread for resolving the address + SDL_CreateThread([](void * pointer) -> int + { + ResolveWorker((ResolveRequest *)pointer); + return 0; + }, 0, req); + } + SDL_UnlockMutex(_resolveMutex); +} + +NetworkAddress::RESOLVE_STATUS NetworkAddress::GetResult(sockaddr_storage * ss, int * ss_len) +{ + SDL_LockMutex(_resolveMutex); + { + const ResolveResult * result = _result.get(); + if (result->status == RESOLVE_OK) + { + *ss = result->ss; + *ss_len = result->ss_len; + } + return result->status; + } + SDL_UnlockMutex(_resolveMutex); +} + +void NetworkAddress::ResolveWorker(ResolveRequest * req) +{ + // Resolve the address + const char * nodeName = req->Host.c_str(); + std::string serviceName = std::to_string(req->Port); + + addrinfo hints = { 0 }; + hints.ai_family = AF_UNSPEC; + if (req->Host.empty()) + { + hints.ai_flags = AI_PASSIVE; + nodeName = nullptr; + } + + addrinfo * result; + getaddrinfo(nodeName, serviceName.c_str(), &hints, &result); + + // Store the result + ResolveResult * resolveResult = req->Result.get(); + if (result != nullptr) + { + resolveResult->status = RESOLVE_OK; + memcpy(&resolveResult->ss, result->ai_addr, result->ai_addrlen); + resolveResult->ss_len = result->ai_addrlen; + freeaddrinfo(result); + } + else + { + resolveResult->status = RESOLVE_FAILED; + } + delete req; +} diff --git a/src/network/NetworkAddress.h b/src/network/NetworkAddress.h new file mode 100644 index 0000000000..73a1d14633 --- /dev/null +++ b/src/network/NetworkAddress.h @@ -0,0 +1,71 @@ +#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers +/***************************************************************************** + * OpenRCT2, an open source clone of Roller Coaster Tycoon 2. + * + * OpenRCT2 is the work of many authors, a full list can be found in contributors.md + * For more information, visit https://github.com/OpenRCT2/OpenRCT2 + * + * OpenRCT2 is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * A full copy of the GNU General Public License can be found in licence.txt + *****************************************************************************/ +#pragma endregion + +#pragma once + +#include +#include +#include "NetworkTypes.h" +#include "../common.h" + +class NetworkAddress final +{ +public: + enum RESOLVE_STATUS + { + RESOLVE_NONE, + RESOLVE_INPROGRESS, + RESOLVE_OK, + RESOLVE_FAILED + }; + + NetworkAddress(); + ~NetworkAddress(); + + void Resolve(const char * host, uint16 port); + void ResolveAsync(const char * host, uint16 port); + + RESOLVE_STATUS GetResult(sockaddr_storage * ss, int * ss_len); + +private: + struct ResolveResult + { + RESOLVE_STATUS status; + sockaddr_storage ss; + int ss_len; + }; + + struct ResolveRequest + { + std::string Host; + uint16 Port; + std::shared_ptr Result; + }; + + /** + * Store for the async result. A new store is created for every request. + * Old requests simply write to an old store that will then be + * automatically deleted by std::shared_ptr. + */ + std::shared_ptr _result; + + /** + * Mutex so synchronoise the requests. + */ + SDL_mutex * _resolveMutex; + + static void ResolveWorker(ResolveRequest * req); +}; diff --git a/src/network/NetworkTypes.h b/src/network/NetworkTypes.h new file mode 100644 index 0000000000..d11cbd5b27 --- /dev/null +++ b/src/network/NetworkTypes.h @@ -0,0 +1,27 @@ +#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers +/***************************************************************************** + * OpenRCT2, an open source clone of Roller Coaster Tycoon 2. + * + * OpenRCT2 is the work of many authors, a full list can be found in contributors.md + * For more information, visit https://github.com/OpenRCT2/OpenRCT2 + * + * OpenRCT2 is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * A full copy of the GNU General Public License can be found in licence.txt + *****************************************************************************/ +#pragma endregion + +#pragma once + +#include + +#ifdef __WINDOWS__ + // winsock2 must be included before windows.h + #include + #include +#else + #include +#endif // __WINDOWS__ diff --git a/src/network/network.cpp b/src/network/network.cpp index d1f25ab42d..51b64cd0ae 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -537,74 +537,6 @@ void NetworkConnection::setLastDisconnectReason(const rct_string_id string_id, v setLastDisconnectReason(buffer); } -NetworkAddress::NetworkAddress() -{ - ss = std::make_shared(); - ss_len = std::make_shared(); - status = std::make_shared(); - *status = RESOLVE_NONE; -} - -void NetworkAddress::Resolve(const char* host, unsigned short port, bool nonblocking) -{ - // A non-blocking hostname resolver - *status = RESOLVE_INPROGRESS; - mutex = SDL_CreateMutex(); - cond = SDL_CreateCond(); - NetworkAddress::host = host; - NetworkAddress::port = port; - SDL_LockMutex(mutex); - SDL_Thread* thread = SDL_CreateThread(ResolveFunc, 0, this); - // The mutex/cond is to make sure ResolveFunc doesn't ever get a dangling pointer - SDL_CondWait(cond, mutex); - SDL_UnlockMutex(mutex); - SDL_DestroyCond(cond); - SDL_DestroyMutex(mutex); - if (!nonblocking) { - int status; - SDL_WaitThread(thread, &status); - } -} - -int NetworkAddress::GetResolveStatus(void) -{ - return *status; -} - -int NetworkAddress::ResolveFunc(void* pointer) -{ - // Copy data for thread safety - NetworkAddress * networkaddress = (NetworkAddress*)pointer; - SDL_LockMutex(networkaddress->mutex); - std::string host; - if (networkaddress->host) host = networkaddress->host; - std::string port = std::to_string(networkaddress->port); - std::shared_ptr ss = networkaddress->ss; - std::shared_ptr ss_len = networkaddress->ss_len; - std::shared_ptr status = networkaddress->status; - SDL_CondSignal(networkaddress->cond); - SDL_UnlockMutex(networkaddress->mutex); - - // Perform the resolve - addrinfo hints; - addrinfo* res; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - if (host.length() == 0) { - hints.ai_flags = AI_PASSIVE; - } - getaddrinfo(host.length() == 0 ? NULL : host.c_str(), port.c_str(), &hints, &res); - if (res) { - memcpy(&(*ss), res->ai_addr, res->ai_addrlen); - *ss_len = res->ai_addrlen; - *status = RESOLVE_OK; - freeaddrinfo(res); - } else { - *status = RESOLVE_FAILED; - } - return 0; -} - Network::Network() { wsa_initialized = false; @@ -706,7 +638,7 @@ bool Network::BeginClient(const char* host, unsigned short port) if (!Init()) return false; - server_address.Resolve(host, port); + server_address.ResolveAsync(host, port); status = NETWORK_STATUS_RESOLVING; char str_resolving[256]; @@ -776,11 +708,16 @@ bool Network::BeginServer(unsigned short port, const char* address) return false; _userManager.Load(); + NetworkAddress networkaddress; - networkaddress.Resolve(address, port, false); + networkaddress.Resolve(address, port); + + sockaddr_storage ss; + int ss_len; + networkaddress.GetResult(&ss, &ss_len); log_verbose("Begin listening for clients"); - listening_socket = socket(networkaddress.ss->ss_family, SOCK_STREAM, IPPROTO_TCP); + listening_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP); if (listening_socket == INVALID_SOCKET) { log_error("Unable to create socket."); return false; @@ -792,7 +729,7 @@ bool Network::BeginServer(unsigned short port, const char* address) log_error("IPV6_V6ONLY failed. %d", LAST_SOCKET_ERROR()); } - if (bind(listening_socket, (sockaddr*)&(*networkaddress.ss), (*networkaddress.ss_len)) != 0) { + if (bind(listening_socket, (sockaddr *)&ss, ss_len) != 0) { closesocket(listening_socket); log_error("Unable to bind to socket."); return false; @@ -934,8 +871,12 @@ void Network::UpdateClient() bool connectfailed = false; switch(status){ case NETWORK_STATUS_RESOLVING:{ - if (server_address.GetResolveStatus() == NetworkAddress::RESOLVE_OK) { - server_connection.socket = socket(server_address.ss->ss_family, SOCK_STREAM, IPPROTO_TCP); + sockaddr_storage ss; + int ss_len; + NetworkAddress::RESOLVE_STATUS result = server_address.GetResult(&ss, &ss_len); + + if (result == NetworkAddress::RESOLVE_OK) { + server_connection.socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP); if (server_connection.socket == INVALID_SOCKET) { log_error("Unable to create socket."); connectfailed = true; @@ -949,8 +890,9 @@ void Network::UpdateClient() break; } - if (connect(server_connection.socket, (sockaddr *)&(*server_address.ss), - (*server_address.ss_len)) == SOCKET_ERROR && (LAST_SOCKET_ERROR() == EINPROGRESS || LAST_SOCKET_ERROR() == EWOULDBLOCK)){ + if (connect(server_connection.socket, (sockaddr *)&ss, ss_len) == SOCKET_ERROR && + (LAST_SOCKET_ERROR() == EINPROGRESS || LAST_SOCKET_ERROR() == EWOULDBLOCK) + ) { char str_connecting[256]; format_string(str_connecting, STR_MULTIPLAYER_CONNECTING, NULL); window_network_status_open(str_connecting, []() -> void { @@ -963,7 +905,7 @@ void Network::UpdateClient() connectfailed = true; break; } - } else if (server_address.GetResolveStatus() == NetworkAddress::RESOLVE_INPROGRESS) { + } else if (result == NetworkAddress::RESOLVE_INPROGRESS) { break; } else { log_error("Could not resolve address."); diff --git a/src/network/network.h b/src/network/network.h index c506d1e2cb..55d1c8486a 100644 --- a/src/network/network.h +++ b/src/network/network.h @@ -123,6 +123,7 @@ extern "C" { #include #include "../core/Json.hpp" #include "../core/Nullable.hpp" +#include "NetworkAddress.h" #include "NetworkKey.h" #include "NetworkUser.h" @@ -315,33 +316,6 @@ private: uint32 last_packet_time; }; -class NetworkAddress -{ -public: - NetworkAddress(); - void Resolve(const char* host, unsigned short port, bool nonblocking = true); - int GetResolveStatus(void); - - std::shared_ptr ss; - std::shared_ptr ss_len; - - enum { - RESOLVE_NONE, - RESOLVE_INPROGRESS, - RESOLVE_OK, - RESOLVE_FAILED - }; - -private: - static int ResolveFunc(void* pointer); - - const char* host = nullptr; - unsigned short port = 0; - SDL_mutex* mutex = nullptr; - SDL_cond* cond = nullptr; - std::shared_ptr status; -}; - class Network { public: