1
0
mirror of https://github.com/OpenRCT2/OpenRCT2 synced 2026-01-22 14:24:33 +01:00

refactor network, create ITcpSocket

Abstracts all socket code into a new class TcpSocket which is only exposed by a light interface, ITcpSocket. This now means that platform specific headers like winsock2.h and sys/socket.h do not have to be included in OpenRCT2 header files reducing include load and other issues.
This commit is contained in:
Ted John
2016-06-01 22:58:21 +01:00
parent 14de1cd5eb
commit 8dfbabbd07
11 changed files with 674 additions and 486 deletions

View File

@@ -14,6 +14,15 @@
*****************************************************************************/
#pragma endregion
#include <SDL_platform.h>
#ifdef __WINDOWS__
// winsock2 must be included before windows.h
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif
extern "C" {
#include "../openrct2.h"
#include "../platform/platform.h"
@@ -161,14 +170,16 @@ void Network::Close()
return;
}
if (mode == NETWORK_MODE_CLIENT) {
closesocket(server_connection.Socket);
} else
if (mode == NETWORK_MODE_SERVER) {
closesocket(listening_socket);
delete server_connection.Socket;
server_connection.Socket = nullptr;
} else if (mode == NETWORK_MODE_SERVER) {
delete listening_socket;
listening_socket = nullptr;
}
mode = NETWORK_MODE_NONE;
status = NETWORK_STATUS_NONE;
_lastConnectStatus = SOCKET_STATUS_CLOSED;
server_connection.AuthStatus = NETWORK_AUTH_NONE;
server_connection.InboundPacket.Clear();
server_connection.SetLastDisconnectReason(nullptr);
@@ -199,14 +210,11 @@ bool Network::BeginClient(const char* host, unsigned short port)
if (!Init())
return false;
server_address.ResolveAsync(host, port);
status = NETWORK_STATUS_RESOLVING;
char str_resolving[256];
format_string(str_resolving, STR_MULTIPLAYER_RESOLVING, NULL);
window_network_status_open(str_resolving, []() -> void {
gNetwork.Close();
});
assert(server_connection.Socket == nullptr);
server_connection.Socket = CreateTcpSocket();
server_connection.Socket->ConnectAsync(host, port);
status = NETWORK_STATUS_CONNECTING;
_lastConnectStatus = SOCKET_STATUS_CLOSED;
BeginChatLog();
@@ -272,43 +280,11 @@ bool Network::BeginServer(unsigned short port, const char* address)
_userManager.Load();
NetworkAddress networkaddress;
networkaddress.Resolve(address, port);
sockaddr_storage ss;
int ss_len;
networkaddress.GetResult(&ss, &ss_len);
log_verbose("Begin listening for clients");
listening_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (listening_socket == INVALID_SOCKET) {
log_error("Unable to create socket.");
return false;
}
// Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections
int value = 0;
if (setsockopt(listening_socket, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&value, sizeof(value)) != 0) {
log_error("IPV6_V6ONLY failed. %d", LAST_SOCKET_ERROR());
}
if (bind(listening_socket, (sockaddr *)&ss, ss_len) != 0) {
closesocket(listening_socket);
log_error("Unable to bind to socket.");
return false;
}
if (listen(listening_socket, SOMAXCONN) != 0) {
closesocket(listening_socket);
log_error("Unable to listen on socket.");
return false;
}
if (!NetworkConnection::SetNonBlocking(listening_socket, true)) {
closesocket(listening_socket);
log_error("Failed to set non-blocking mode.");
return false;
}
assert(listening_socket == nullptr);
listening_socket = CreateTcpSocket();
listening_socket->Listen(address, port);
ServerName = gConfigNetwork.server_name;
ServerDescription = gConfigNetwork.server_description;
@@ -421,113 +397,74 @@ void Network::UpdateServer()
break;
}
SOCKET socket = accept(listening_socket, NULL, NULL);
if (socket == INVALID_SOCKET) {
if (LAST_SOCKET_ERROR() != EWOULDBLOCK) {
PrintError();
log_error("Failed to accept client.");
}
} else {
if (!NetworkConnection::SetNonBlocking(socket, true)) {
closesocket(socket);
log_error("Failed to set non-blocking mode.");
} else {
AddClient(socket);
}
ITcpSocket * tcpSocket = listening_socket->Accept();
if (tcpSocket != nullptr) {
AddClient(tcpSocket);
}
}
void Network::UpdateClient()
{
bool connectfailed = false;
switch(status){
case NETWORK_STATUS_RESOLVING:{
sockaddr_storage ss;
int ss_len;
NetworkAddress::RESOLVE_STATUS result = server_address.GetResult(&ss, &ss_len);
if (result == NetworkAddress::RESOLVE_OK) {
server_connection.Socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (server_connection.Socket == INVALID_SOCKET) {
log_error("Unable to create socket.");
connectfailed = true;
break;
case NETWORK_STATUS_CONNECTING:
{
switch (server_connection.Socket->GetStatus()) {
case SOCKET_STATUS_RESOLVING:
{
if (_lastConnectStatus != SOCKET_STATUS_RESOLVING)
{
_lastConnectStatus = SOCKET_STATUS_RESOLVING;
char str_resolving[256];
format_string(str_resolving, STR_MULTIPLAYER_RESOLVING, NULL);
window_network_status_open(str_resolving, []() -> void {
gNetwork.Close();
});
}
server_connection.SetTCPNoDelay(true);
if (!server_connection.SetNonBlocking(true)) {
log_error("Failed to set non-blocking mode.");
connectfailed = true;
break;
}
if (connect(server_connection.Socket, (sockaddr *)&ss, ss_len) == SOCKET_ERROR &&
(LAST_SOCKET_ERROR() == EINPROGRESS || LAST_SOCKET_ERROR() == EWOULDBLOCK)
) {
break;
}
case SOCKET_STATUS_CONNECTING:
{
if (_lastConnectStatus != SOCKET_STATUS_CONNECTING)
{
_lastConnectStatus = SOCKET_STATUS_CONNECTING;
char str_connecting[256];
format_string(str_connecting, STR_MULTIPLAYER_CONNECTING, NULL);
window_network_status_open(str_connecting, []() -> void {
gNetwork.Close();
});
server_connect_time = SDL_GetTicks();
status = NETWORK_STATUS_CONNECTING;
} else {
log_error("connect() failed %d", LAST_SOCKET_ERROR());
connectfailed = true;
break;
}
} else if (result == NetworkAddress::RESOLVE_INPROGRESS) {
break;
} else {
log_error("Could not resolve address.");
connectfailed = true;
}
}break;
case NETWORK_STATUS_CONNECTING:{
int error = 0;
socklen_t len = sizeof(error);
int result = getsockopt(server_connection.Socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len);
if (result != 0) {
log_error("getsockopt failed with error %d", LAST_SOCKET_ERROR());
break;
}
if (error != 0) {
log_error("Connection failed %d", error);
connectfailed = true;
case NETWORK_STATUS_CONNECTED:
{
status = NETWORK_STATUS_CONNECTED;
server_connection.ResetLastPacketTime();
Client_Send_TOKEN();
char str_authenticating[256];
format_string(str_authenticating, STR_MULTIPLAYER_AUTHENTICATING, NULL);
window_network_status_open(str_authenticating, []() -> void {
gNetwork.Close();
});
break;
}
if (SDL_TICKS_PASSED(SDL_GetTicks(), server_connect_time + 3000)) {
log_error("Connection timed out.");
connectfailed = true;
break;
}
fd_set writeFD;
FD_ZERO(&writeFD);
FD_SET(server_connection.Socket, &writeFD);
timeval timeout;
timeout.tv_sec = 0;
timeout.tv_usec = 0;
if (select(server_connection.Socket + 1, NULL, &writeFD, NULL, &timeout) > 0) {
error = 0;
socklen_t len = sizeof(error);
result = getsockopt(server_connection.Socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len);
if (result != 0) {
log_error("getsockopt failed with error %d", LAST_SOCKET_ERROR());
break;
}
if (error == 0) {
status = NETWORK_STATUS_CONNECTED;
server_connection.ResetLastPacketTime();
Client_Send_TOKEN();
char str_authenticating[256];
format_string(str_authenticating, STR_MULTIPLAYER_AUTHENTICATING, NULL);
window_network_status_open(str_authenticating, []() -> void {
gNetwork.Close();
});
default:
{
const char * error = server_connection.Socket->GetError();
if (error != nullptr) {
Console::Error::WriteLine(error);
}
Close();
window_network_status_close();
window_error_open(STR_UNABLE_TO_CONNECT_TO_SERVER, STR_NONE);
break;
}
}break;
}
break;
}
case NETWORK_STATUS_CONNECTED:
{
if (!ProcessConnection(server_connection)) {
// Do not show disconnect message window when password window closed/canceled
if (server_connection.AuthStatus == NETWORK_AUTH_REQUIREPASSWORD) {
@@ -560,11 +497,6 @@ void Network::UpdateClient()
}
break;
}
if (connectfailed) {
Close();
window_network_status_close();
window_error_open(STR_UNABLE_TO_CONNECT_TO_SERVER, STR_NONE);
}
}
@@ -659,7 +591,7 @@ void Network::KickPlayer(int playerId)
char str_disconnect_msg[256];
format_string(str_disconnect_msg, STR_MULTIPLAYER_KICKED_REASON, NULL);
Server_Send_SETDISCONNECTMSG(*(*it), str_disconnect_msg);
shutdown((*it)->Socket, SHUT_RD);
(*it)->Socket->Disconnect();
(*it)->SendQueuedPackets();
break;
}
@@ -674,7 +606,7 @@ void Network::SetPassword(const char* password)
void Network::ShutdownClient()
{
if (GetMode() == NETWORK_MODE_CLIENT) {
shutdown(server_connection.Socket, SHUT_RDWR);
server_connection.Socket->Disconnect();
}
}
@@ -1035,8 +967,8 @@ void Network::Server_Send_AUTH(NetworkConnection& connection)
}
connection.QueuePacket(std::move(packet));
if (connection.AuthStatus != NETWORK_AUTH_OK && connection.AuthStatus != NETWORK_AUTH_REQUIREPASSWORD) {
shutdown(connection.Socket, SHUT_RD);
connection.SendQueuedPackets();
connection.Socket->Disconnect();
}
}
@@ -1324,11 +1256,10 @@ void Network::ProcessGameCommandQueue()
}
}
void Network::AddClient(SOCKET socket)
void Network::AddClient(ITcpSocket * socket)
{
auto connection = std::unique_ptr<NetworkConnection>(new NetworkConnection); // change to make_unique in c++14
connection->Socket = socket;
connection->SetTCPNoDelay(true);
client_connection_list.push_back(std::move(connection));
}
@@ -1442,21 +1373,6 @@ std::string Network::MakePlayerNameUnique(const std::string &name)
return new_name;
}
void Network::PrintError()
{
#ifdef __WINDOWS__
wchar_t *s = NULL;
FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL,
LAST_SOCKET_ERROR(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, NULL);
fprintf(stderr, "%S\n", s);
LocalFree(s);
#else
char *s = strerror(LAST_SOCKET_ERROR());
fprintf(stderr, "%s\n", s);
#endif
}
void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet)
{
utf8 keyPath[MAX_PATH];
@@ -1471,7 +1387,7 @@ void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket&
if (!ok) {
log_error("Failed to load key %s", keyPath);
connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
return;
}
uint32 challenge_size;
@@ -1486,7 +1402,7 @@ void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket&
if (!ok) {
log_error("Failed to sign server's challenge.");
connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
return;
}
// Don't keep private key in memory. There's no need and it may get leaked
@@ -1505,37 +1421,37 @@ void Network::Client_Handle_AUTH(NetworkConnection& connection, NetworkPacket& p
break;
case NETWORK_AUTH_BADNAME:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PLAYER_NAME);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_BADVERSION:
{
const char *version = packet.ReadString();
connection.SetLastDisconnectReason(STR_MULTIPLAYER_INCORRECT_SOFTWARE_VERSION, &version);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
}
case NETWORK_AUTH_BADPASSWORD:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PASSWORD);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_VERIFICATIONFAILURE:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_FULL:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_SERVER_FULL);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_REQUIREPASSWORD:
window_network_status_open_password();
break;
case NETWORK_AUTH_UNKNOWN_KEY_DISALLOWED:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_UNKNOWN_KEY_DISALLOWED);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
default:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_INCORRECT_SOFTWARE_VERSION);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
}
}
@@ -1964,6 +1880,19 @@ void Network::Client_Handle_GAMEINFO(NetworkConnection& connection, NetworkPacke
json_decref(root);
}
namespace Convert
{
uint16 HostToNetwork(uint16 value)
{
return htons(value);
}
uint16 NetworkToHost(uint16 value)
{
return ntohs(value);
}
}
int network_init()
{
return gNetwork.Init();