diff --git a/src/openrct2/core/Endianness.h b/src/openrct2/core/Endianness.h index 22db6e49e1..5bebd494b6 100644 --- a/src/openrct2/core/Endianness.h +++ b/src/openrct2/core/Endianness.h @@ -11,6 +11,9 @@ #include "../common.h" +#include +#include + template struct ByteSwapT { }; @@ -58,6 +61,22 @@ template<> struct ByteSwapT<8> template static T ByteSwapBE(const T& value) { using ByteSwap = ByteSwapT; - typename ByteSwap::UIntType result = ByteSwap::SwapBE(reinterpret_cast(value)); - return *reinterpret_cast(&result); + using UIntType = typename ByteSwap::UIntType; + + if constexpr (std::is_enum_v || std::is_integral_v) + { + auto result = ByteSwap::SwapBE(static_cast(value)); + return static_cast(result); + } + else + { + // Complex type, reinterpret_cast is not safe for this case. + // Create a temporary of size(T) as unsigned type via copy instead. + UIntType temp; + std::memcpy(&temp, &value, sizeof(T)); + auto result = ByteSwap::SwapBE(temp); + T res; + std::memcpy(&res, &result, sizeof(T)); + return res; + } } diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index c2973b8cf0..01769bd5bc 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -721,7 +721,7 @@ const char* NetworkBase::FormatChat(NetworkPlayer* fromplayer, const char* text) return formatted; } -void NetworkBase::SendPacketToClients(NetworkPacket& packet, bool front, bool gameCmd) +void NetworkBase::SendPacketToClients(const NetworkPacket& packet, bool front, bool gameCmd) { for (auto& client_connection : client_connection_list) { @@ -742,7 +742,8 @@ void NetworkBase::SendPacketToClients(NetworkPacket& packet, bool front, bool ga continue; } } - client_connection->QueuePacket(NetworkPacket::Duplicate(packet), front); + auto packetCopy = packet; + client_connection->QueuePacket(std::move(packetCopy), front); } } @@ -1207,16 +1208,16 @@ void NetworkBase::Client_Send_RequestGameState(uint32_t tick) } log_verbose("Requesting gamestate from server for tick %u", tick); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::RequestGameState) << tick; + + NetworkPacket packet(NetworkCommand::RequestGameState); + packet << tick; _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Client_Send_TOKEN() { log_verbose("requesting token"); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Token); + NetworkPacket packet(NetworkCommand::Token); _serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED; _serverConnection->QueuePacket(std::move(packet)); } @@ -1224,15 +1225,14 @@ void NetworkBase::Client_Send_TOKEN() void NetworkBase::Client_Send_AUTH( const std::string& name, const std::string& password, const std::string& pubkey, const std::vector& signature) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Auth); - packet->WriteString(network_get_version().c_str()); - packet->WriteString(name.c_str()); - packet->WriteString(password.c_str()); - packet->WriteString(pubkey.c_str()); + NetworkPacket packet(NetworkCommand::Auth); + packet.WriteString(network_get_version().c_str()); + packet.WriteString(name.c_str()); + packet.WriteString(password.c_str()); + packet.WriteString(pubkey.c_str()); assert(signature.size() <= static_cast(UINT32_MAX)); - *packet << static_cast(signature.size()); - packet->Write(signature.data(), signature.size()); + packet << static_cast(signature.size()); + packet.Write(signature.data(), signature.size()); _serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED; _serverConnection->QueuePacket(std::move(packet)); } @@ -1240,21 +1240,21 @@ void NetworkBase::Client_Send_AUTH( void NetworkBase::Client_Send_MAPREQUEST(const std::vector& objects) { log_verbose("client requests %u objects", uint32_t(objects.size())); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::MapRequest) << static_cast(objects.size()); + NetworkPacket packet(NetworkCommand::MapRequest); + packet << static_cast(objects.size()); for (const auto& object : objects) { log_verbose("client requests object %s", object.c_str()); - packet->Write(reinterpret_cast(object.c_str()), 8); + packet.Write(reinterpret_cast(object.c_str()), 8); } _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_TOKEN(NetworkConnection& connection) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Token) << static_cast(connection.Challenge.size()); - packet->Write(connection.Challenge.data(), connection.Challenge.size()); + NetworkPacket packet(NetworkCommand::Token); + packet << static_cast(connection.Challenge.size()); + packet.Write(connection.Challenge.data(), connection.Challenge.size()); connection.QueuePacket(std::move(packet)); } @@ -1265,9 +1265,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST( if (objects.empty()) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::ObjectsList) << static_cast(0) - << static_cast(objects.size()); + NetworkPacket packet(NetworkCommand::ObjectsList); + packet << static_cast(0) << static_cast(objects.size()); connection.QueuePacket(std::move(packet)); } @@ -1277,13 +1276,12 @@ void NetworkBase::Server_Send_OBJECTS_LIST( { const auto* object = objects[i]; - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::ObjectsList) << static_cast(i) - << static_cast(objects.size()); + NetworkPacket packet(NetworkCommand::ObjectsList); + packet << static_cast(i) << static_cast(objects.size()); log_verbose("Object %.8s (checksum %x)", object->ObjectEntry.name, object->ObjectEntry.checksum); - packet->Write(reinterpret_cast(object->ObjectEntry.name), 8); - *packet << object->ObjectEntry.checksum << object->ObjectEntry.flags; + packet.Write(reinterpret_cast(object->ObjectEntry.name), 8); + packet << object->ObjectEntry.checksum << object->ObjectEntry.flags; connection.QueuePacket(std::move(packet)); } @@ -1292,8 +1290,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST( void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Scripts); + NetworkPacket packet(NetworkCommand::Scripts); + # ifdef ENABLE_SCRIPTING using namespace OpenRCT2::Scripting; @@ -1310,18 +1308,18 @@ void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const } log_verbose("Server sends %u scripts", pluginsToSend.size()); - *packet << static_cast(pluginsToSend.size()); + packet << static_cast(pluginsToSend.size()); for (const auto& plugin : pluginsToSend) { const auto& metadata = plugin->GetMetadata(); log_verbose("Script %s", metadata.Name.c_str()); const auto& code = plugin->GetCode(); - *packet << static_cast(code.size()); - packet->Write(reinterpret_cast(code.c_str()), code.size()); + packet << static_cast(code.size()); + packet.Write(reinterpret_cast(code.c_str()), code.size()); } # else - *packet << static_cast(0); + packet << static_cast(0); # endif connection.QueuePacket(std::move(packet)); } @@ -1330,9 +1328,7 @@ void NetworkBase::Client_Send_HEARTBEAT(NetworkConnection& connection) const { log_verbose("Sending heartbeat"); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Heartbeat); - + NetworkPacket packet(NetworkCommand::Heartbeat); connection.QueuePacket(std::move(packet)); } @@ -1364,11 +1360,11 @@ void NetworkBase::Server_Send_AUTH(NetworkConnection& connection) { new_playerid = connection.Player->Id; } - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Auth) << static_cast(connection.AuthStatus) << new_playerid; + NetworkPacket packet(NetworkCommand::Auth); + packet << static_cast(connection.AuthStatus) << new_playerid; if (connection.AuthStatus == NETWORK_AUTH_BADVERSION) { - packet->WriteString(network_get_version().c_str()); + packet.WriteString(network_get_version().c_str()); } connection.QueuePacket(std::move(packet)); if (connection.AuthStatus != NETWORK_AUTH_OK && connection.AuthStatus != NETWORK_AUTH_REQUIREPASSWORD) @@ -1408,16 +1404,16 @@ void NetworkBase::Server_Send_MAP(NetworkConnection* connection) for (size_t i = 0; i < out_size; i += chunksize) { size_t datasize = std::min(chunksize, out_size - i); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Map) << static_cast(out_size) << static_cast(i); - packet->Write(&header[i], datasize); + NetworkPacket packet(NetworkCommand::Map); + packet << static_cast(out_size) << static_cast(i); + packet.Write(&header[i], datasize); if (connection) { connection->QueuePacket(std::move(packet)); } else { - SendPacketToClients(*packet); + SendPacketToClients(packet); } } free(header); @@ -1478,22 +1474,20 @@ uint8_t* NetworkBase::save_for_network(size_t& out_size, const std::vector packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Chat); - packet->WriteString(text); + NetworkPacket packet(NetworkCommand::Chat); + packet.WriteString(text); _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& playerIds) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Chat); - packet->WriteString(text); + NetworkPacket packet(NetworkCommand::Chat); + packet.WriteString(text); if (playerIds.empty()) { // Empty players / default value means send to all players - SendPacketToClients(*packet); + SendPacketToClients(packet); } else { @@ -1502,7 +1496,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& auto conn = GetPlayerConnection(playerId); if (conn != nullptr && !conn->IsDisconnected) { - conn->QueuePacket(NetworkPacket::Duplicate(*packet)); + conn->QueuePacket(packet); } } } @@ -1510,7 +1504,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action) { - std::unique_ptr packet(NetworkPacket::Allocate()); + NetworkPacket packet(NetworkCommand::GameAction); uint32_t networkId = 0; networkId = ++_actionId; @@ -1525,26 +1519,26 @@ void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action) DataSerialiser stream(true); action->Serialise(stream); - *packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; + packet << gCurrentTicks << action->GetType() << stream; _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_GAME_ACTION(const GameAction* action) { - std::unique_ptr packet(NetworkPacket::Allocate()); + NetworkPacket packet(NetworkCommand::GameAction); DataSerialiser stream(true); action->Serialise(stream); - *packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; + packet << gCurrentTicks << action->GetType() << stream; - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_TICK() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Tick) << gCurrentTicks << scenario_rand_state().s0; + NetworkPacket packet(NetworkCommand::Tick); + packet << gCurrentTicks << scenario_rand_state().s0; uint32_t flags = 0; // Simple counter which limits how often a sprite checksum gets sent. // This can get somewhat expensive, so we don't want to push it every tick in release, @@ -1558,75 +1552,72 @@ void NetworkBase::Server_Send_TICK() } // Send flags always, so we can understand packet structure on the other end, // and allow for some expansion. - *packet << flags; + packet << flags; if (flags & NETWORK_TICK_FLAG_CHECKSUMS) { rct_sprite_checksum checksum = sprite_checksum(); - packet->WriteString(checksum.ToString().c_str()); + packet.WriteString(checksum.ToString().c_str()); } - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_PLAYERINFO(int32_t playerId) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::PlayerInfo) << gCurrentTicks; + NetworkPacket packet(NetworkCommand::PlayerInfo); + packet << gCurrentTicks; auto* player = GetPlayerByID(playerId); if (player == nullptr) return; - player->Write(*packet); - SendPacketToClients(*packet); + player->Write(packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_PLAYERLIST() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::PlayerList) << gCurrentTicks << static_cast(player_list.size()); + NetworkPacket packet(NetworkCommand::PlayerList); + packet << gCurrentTicks << static_cast(player_list.size()); for (auto& player : player_list) { - player->Write(*packet); + player->Write(packet); } - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Client_Send_PING() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Ping); + NetworkPacket packet(NetworkCommand::Ping); _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_PING() { last_ping_sent_time = platform_get_ticks(); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Ping); + NetworkPacket packet(NetworkCommand::Ping); for (auto& client_connection : client_connection_list) { client_connection->PingTime = platform_get_ticks(); } - SendPacketToClients(*packet, true); + SendPacketToClients(packet, true); } void NetworkBase::Server_Send_PINGLIST() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::PingList) << static_cast(player_list.size()); + NetworkPacket packet(NetworkCommand::PingList); + packet << static_cast(player_list.size()); for (auto& player : player_list) { - *packet << player->Id << player->Ping; + packet << player->Id << player->Ping; } - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_SETDISCONNECTMSG(NetworkConnection& connection, const char* msg) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::DisconnectMessage); - packet->WriteString(msg); + NetworkPacket packet(NetworkCommand::DisconnectMessage); + packet.WriteString(msg); connection.QueuePacket(std::move(packet)); } @@ -1646,8 +1637,7 @@ json_t* NetworkBase::GetServerInfoAsJson() const void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::GameInfo); + NetworkPacket packet(NetworkCommand::GameInfo); # ifndef DISABLE_HTTP json_t* obj = GetServerInfoAsJson(); @@ -1658,8 +1648,8 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) json_object_set_new(jsonProvider, "website", json_string(gConfigNetwork.provider_website.c_str())); json_object_set_new(obj, "provider", jsonProvider); - packet->WriteString(json_dumps(obj, 0)); - *packet << _serverState.gamestateSnapshotsEnabled; + packet.WriteString(json_dumps(obj, 0)); + packet << _serverState.gamestateSnapshotsEnabled; json_decref(obj); # endif @@ -1668,39 +1658,37 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) void NetworkBase::Server_Send_SHOWERROR(NetworkConnection& connection, rct_string_id title, rct_string_id message) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::ShowError) << title << message; + NetworkPacket packet(NetworkCommand::ShowError); + packet << title << message; connection.QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_GROUPLIST(NetworkConnection& connection) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::GroupList) << static_cast(group_list.size()) << default_group; + NetworkPacket packet(NetworkCommand::GroupList); + packet << static_cast(group_list.size()) << default_group; for (auto& group : group_list) { - group->Write(*packet); + group->Write(packet); } connection.QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_EVENT_PLAYER_JOINED(const char* playerName) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Event); - *packet << static_cast(SERVER_EVENT_PLAYER_JOINED); - packet->WriteString(playerName); - SendPacketToClients(*packet); + NetworkPacket packet(NetworkCommand::Event); + packet << static_cast(SERVER_EVENT_PLAYER_JOINED); + packet.WriteString(playerName); + SendPacketToClients(packet); } void NetworkBase::Server_Send_EVENT_PLAYER_DISCONNECTED(const char* playerName, const char* reason) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Event); - *packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); - packet->WriteString(playerName); - packet->WriteString(reason); - SendPacketToClients(*packet); + NetworkPacket packet(NetworkCommand::Event); + packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); + packet.WriteString(playerName); + packet.WriteString(reason); + SendPacketToClients(packet); } bool NetworkBase::ProcessConnection(NetworkConnection& connection) @@ -1748,11 +1736,8 @@ bool NetworkBase::ProcessConnection(NetworkConnection& connection) void NetworkBase::ProcessPacket(NetworkConnection& connection, NetworkPacket& packet) { - std::underlying_type::type command; - packet >> command; - const auto& handlerList = GetMode() == NETWORK_MODE_SERVER ? server_command_handlers : client_command_handlers; - auto it = handlerList.find(static_cast(command)); + auto it = handlerList.find(packet.GetCommand()); if (it != handlerList.end()) { auto commandHandler = it->second; @@ -2240,11 +2225,11 @@ void NetworkBase::Server_Handle_REQUEST_GAMESTATE(NetworkConnection& connection, dataSize = snapshotMemory.GetLength() - bytesSent; } - std::unique_ptr gameStateChunk(NetworkPacket::Allocate()); - *gameStateChunk << static_cast(NetworkCommand::GameState) << tick << length << bytesSent << dataSize; - gameStateChunk->Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); + NetworkPacket packetGameStateChunk(NetworkCommand::GameState); + packetGameStateChunk << tick << length << bytesSent << dataSize; + packetGameStateChunk.Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); - connection.QueuePacket(std::move(gameStateChunk)); + connection.QueuePacket(std::move(packetGameStateChunk)); bytesSent += dataSize; } @@ -2662,7 +2647,7 @@ void NetworkBase::Client_Handle_MAP([[maybe_unused]] NetworkConnection& connecti { uint32_t size, offset; packet >> size >> offset; - int32_t chunksize = static_cast(packet.Size - packet.BytesRead); + int32_t chunksize = static_cast(packet.Header.Size - packet.BytesRead); if (chunksize <= 0) { return; @@ -2937,7 +2922,7 @@ void NetworkBase::Client_Handle_GAME_ACTION([[maybe_unused]] NetworkConnection& packet >> tick >> actionType; MemoryStream stream; - size_t size = packet.Size - packet.BytesRead; + const size_t size = packet.Header.Size - packet.BytesRead; stream.WriteArray(packet.Read(size), size); stream.SetPosition(0); @@ -3027,7 +3012,7 @@ void NetworkBase::Server_Handle_GAME_ACTION(NetworkConnection& connection, Netwo } DataSerialiser stream(false); - size_t size = packet.Size - packet.BytesRead; + const size_t size = packet.Header.Size - packet.BytesRead; stream.GetStream().WriteArray(packet.Read(size), size); stream.GetStream().SetPosition(0); @@ -3208,8 +3193,7 @@ void NetworkBase::Client_Handle_EVENT([[maybe_unused]] NetworkConnection& connec void NetworkBase::Client_Send_GAMEINFO() { log_verbose("requesting gameinfo"); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::GameInfo); + NetworkPacket packet(NetworkCommand::GameInfo); _serverConnection->QueuePacket(std::move(packet)); } diff --git a/src/openrct2/network/NetworkBase.h b/src/openrct2/network/NetworkBase.h index ad8eed136e..0f2f1fd0b5 100644 --- a/src/openrct2/network/NetworkBase.h +++ b/src/openrct2/network/NetworkBase.h @@ -113,7 +113,7 @@ public: // Client void ProcessPlayerInfo(); void ProcessDisconnectedClients(); static const char* FormatChat(NetworkPlayer* fromplayer, const char* text); - void SendPacketToClients(NetworkPacket& packet, bool front = false, bool gameCmd = false); + void SendPacketToClients(const NetworkPacket& packet, bool front = false, bool gameCmd = false); bool CheckSRAND(uint32_t tick, uint32_t srand0); bool CheckDesynchronizaton(); void RequestStateSnapshot(); diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index 270e68098b..1f21a8c3b1 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -18,6 +18,7 @@ # include "network.h" constexpr size_t NETWORK_DISCONNECT_REASON_BUFFER_SIZE = 256; +constexpr size_t NetworkBufferSize = 1024; NetworkConnection::NetworkConnection() { @@ -31,47 +32,61 @@ NetworkConnection::~NetworkConnection() int32_t NetworkConnection::ReadPacket() { - if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Size)) + size_t bytesRead = 0; + + // Read packet header. + auto& header = InboundPacket.Header; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - // read packet size - void* buffer = &(reinterpret_cast(&InboundPacket.Size))[InboundPacket.BytesTransferred]; - size_t bufferLength = sizeof(InboundPacket.Size) - InboundPacket.BytesTransferred; - size_t readBytes; - NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes); + const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred; + + uint8_t* buffer = reinterpret_cast(&InboundPacket.Header); + + NETWORK_READPACKET status = Socket->ReceiveData(buffer, missingLength, &bytesRead); if (status != NETWORK_READPACKET_SUCCESS) { return status; } - InboundPacket.BytesTransferred += readBytes; - if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size)) + InboundPacket.BytesTransferred += bytesRead; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - InboundPacket.Size = Convert::NetworkToHost(InboundPacket.Size); - if (InboundPacket.Size == 0) // Can't have a size 0 packet - { - return NETWORK_READPACKET_DISCONNECTED; - } - InboundPacket.Data->resize(InboundPacket.Size); + // If still not enough data for header, keep waiting. + return NETWORK_READPACKET_MORE_DATA; } + + // Normalise values. + header.Size = Convert::NetworkToHost(header.Size); + header.Id = ByteSwapBE(header.Id); + + // NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size. + // Previously the Id field was not part of the header rather part of the body. + header.Size -= sizeof(header.Id); + + // Fall-through: Read rest of packet. } - else + + // Read packet body. { - // read packet data - if (InboundPacket.Data->capacity() > 0) + const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header)); + + uint8_t buffer[NetworkBufferSize]; + + if (missingLength > 0) { - void* buffer = &InboundPacket.GetData()[InboundPacket.BytesTransferred - sizeof(InboundPacket.Size)]; - size_t bufferLength = sizeof(InboundPacket.Size) + InboundPacket.Size - InboundPacket.BytesTransferred; - size_t readBytes; - NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes); + NETWORK_READPACKET status = Socket->ReceiveData(buffer, std::min(missingLength, NetworkBufferSize), &bytesRead); if (status != NETWORK_READPACKET_SUCCESS) { return status; } - InboundPacket.BytesTransferred += readBytes; + InboundPacket.BytesTransferred += bytesRead; + InboundPacket.Write(buffer, bytesRead); } - if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size) + InboundPacket.Size) + + if (InboundPacket.Data.size() == header.Size) { + // Received complete packet. _lastPacketTime = platform_get_ticks(); RecordPacketStats(InboundPacket, false); @@ -79,26 +94,34 @@ int32_t NetworkConnection::ReadPacket() return NETWORK_READPACKET_SUCCESS; } } + return NETWORK_READPACKET_MORE_DATA; } bool NetworkConnection::SendPacket(NetworkPacket& packet) { - uint16_t sizen = Convert::HostToNetwork(packet.Size); - std::vector tosend; - tosend.reserve(sizeof(sizen) + packet.Size); - tosend.insert(tosend.end(), reinterpret_cast(&sizen), reinterpret_cast(&sizen) + sizeof(sizen)); - tosend.insert(tosend.end(), packet.Data->begin(), packet.Data->end()); + auto header = packet.Header; - const void* buffer = &tosend[packet.BytesTransferred]; - size_t bufferSize = tosend.size() - packet.BytesTransferred; - size_t sent = Socket->SendData(buffer, bufferSize); + std::vector buffer; + buffer.reserve(sizeof(header) + header.Size); + + // NOTE: For compatibility reasons for the master server we need to add sizeof(Header.Id) to the size. + // Previously the Id field was not part of the header rather part of the body. + header.Size += sizeof(header.Id); + header.Size = Convert::HostToNetwork(header.Size); + header.Id = ByteSwapBE(header.Id); + + buffer.insert(buffer.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); + buffer.insert(buffer.end(), packet.Data.begin(), packet.Data.end()); + + size_t bufferSize = buffer.size() - packet.BytesTransferred; + size_t sent = Socket->SendData(buffer.data() + packet.BytesTransferred, bufferSize); if (sent > 0) { packet.BytesTransferred += sent; } - bool sendComplete = packet.BytesTransferred == tosend.size(); + bool sendComplete = packet.BytesTransferred == buffer.size(); if (sendComplete) { RecordPacketStats(packet, true); @@ -106,15 +129,15 @@ bool NetworkConnection::SendPacket(NetworkPacket& packet) return sendComplete; } -void NetworkConnection::QueuePacket(std::unique_ptr packet, bool front) +void NetworkConnection::QueuePacket(NetworkPacket&& packet, bool front) { - if (AuthStatus == NETWORK_AUTH_OK || !packet->CommandRequiresAuth()) + if (AuthStatus == NETWORK_AUTH_OK || !packet.CommandRequiresAuth()) { - packet->Size = static_cast(packet->Data->size()); + packet.Header.Size = static_cast(packet.Data.size()); if (front) { // If the first packet was already partially sent add new packet to second position - if (!_outboundPackets.empty() && _outboundPackets.front()->BytesTransferred > 0) + if (!_outboundPackets.empty() && _outboundPackets.front().BytesTransferred > 0) { auto it = _outboundPackets.begin(); it++; // Second position @@ -134,9 +157,9 @@ void NetworkConnection::QueuePacket(std::unique_ptr packet, bool void NetworkConnection::SendQueuedPackets() { - while (!_outboundPackets.empty() && SendPacket(*_outboundPackets.front())) + while (!_outboundPackets.empty() && SendPacket(_outboundPackets.front())) { - _outboundPackets.remove(_outboundPackets.front()); + _outboundPackets.pop_front(); } } diff --git a/src/openrct2/network/NetworkConnection.h b/src/openrct2/network/NetworkConnection.h index 0dcd379e73..77000bf04e 100644 --- a/src/openrct2/network/NetworkConnection.h +++ b/src/openrct2/network/NetworkConnection.h @@ -16,7 +16,7 @@ # include "NetworkTypes.h" # include "Socket.h" -# include +# include # include # include @@ -41,7 +41,13 @@ public: ~NetworkConnection(); int32_t ReadPacket(); - void QueuePacket(std::unique_ptr packet, bool front = false); + void QueuePacket(NetworkPacket&& packet, bool front = false); + void QueuePacket(const NetworkPacket& packet, bool front = false) + { + auto copy = packet; + return QueuePacket(std::move(copy), front); + } + void SendQueuedPackets(); void ResetLastPacketTime(); bool ReceivedPacketRecently(); @@ -51,7 +57,7 @@ public: void SetLastDisconnectReason(const rct_string_id string_id, void* args = nullptr); private: - std::list> _outboundPackets; + std::deque _outboundPackets; uint32_t _lastPacketTime = 0; utf8* _lastDisconnectReason = nullptr; diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index e1e4d860c6..c781fd1978 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -15,35 +15,31 @@ # include -std::unique_ptr NetworkPacket::Allocate() +NetworkPacket::NetworkPacket(NetworkCommand id) + : Header{ 0, id } { - return std::make_unique(); -} - -std::unique_ptr NetworkPacket::Duplicate(NetworkPacket& packet) -{ - return std::make_unique(packet); } uint8_t* NetworkPacket::GetData() { - return &(*Data)[0]; + return Data.data(); +} + +const uint8_t* NetworkPacket::GetData() const +{ + return Data.data(); } NetworkCommand NetworkPacket::GetCommand() const { - if (Data->size() < sizeof(uint32_t)) - return NetworkCommand::Invalid; - - const uint32_t commandId = ByteSwapBE(*reinterpret_cast(&(*Data)[0])); - return static_cast(commandId); + return Header.Id; } void NetworkPacket::Clear() { BytesTransferred = 0; BytesRead = 0; - Data->clear(); + Data.clear(); } bool NetworkPacket::CommandRequiresAuth() @@ -63,9 +59,10 @@ bool NetworkPacket::CommandRequiresAuth() } } -void NetworkPacket::Write(const uint8_t* bytes, size_t size) +void NetworkPacket::Write(const void* bytes, size_t size) { - Data->insert(Data->end(), bytes, bytes + size); + const uint8_t* src = reinterpret_cast(bytes); + Data.insert(Data.end(), src, src + size); } void NetworkPacket::WriteString(const utf8* string) @@ -75,7 +72,7 @@ void NetworkPacket::WriteString(const utf8* string) const uint8_t* NetworkPacket::Read(size_t size) { - if (BytesRead + size > NetworkPacket::Size) + if (BytesRead + size > Header.Size) { return nullptr; } @@ -91,7 +88,7 @@ const utf8* NetworkPacket::ReadString() { char* str = reinterpret_cast(&GetData()[BytesRead]); char* strend = str; - while (BytesRead < Size && *strend != 0) + while (BytesRead < Header.Size && *strend != 0) { BytesRead++; strend++; diff --git a/src/openrct2/network/NetworkPacket.h b/src/openrct2/network/NetworkPacket.h index 507082e81e..ffa464333a 100644 --- a/src/openrct2/network/NetworkPacket.h +++ b/src/openrct2/network/NetworkPacket.h @@ -16,18 +16,23 @@ #include #include -class NetworkPacket final +#pragma pack(push, 1) +struct PacketHeader { -public: uint16_t Size = 0; - std::shared_ptr> Data = std::make_shared>(); - size_t BytesTransferred = 0; - size_t BytesRead = 0; + NetworkCommand Id = NetworkCommand::Invalid; +}; +static_assert(sizeof(PacketHeader) == 6); +#pragma pack(pop) - static std::unique_ptr Allocate(); - static std::unique_ptr Duplicate(NetworkPacket& packet); +struct NetworkPacket final +{ + NetworkPacket() = default; + NetworkPacket(NetworkCommand id); uint8_t* GetData(); + const uint8_t* GetData() const; + NetworkCommand GetCommand() const; void Clear(); @@ -36,12 +41,12 @@ public: const uint8_t* Read(size_t size); const utf8* ReadString(); - void Write(const uint8_t* bytes, size_t size); + void Write(const void* bytes, size_t size); void WriteString(const utf8* string); template NetworkPacket& operator>>(T& value) { - if (BytesRead + sizeof(value) > Size) + if (BytesRead + sizeof(value) > Header.Size) { value = T{}; } @@ -58,8 +63,7 @@ public: template NetworkPacket& operator<<(T value) { T swapped = ByteSwapBE(value); - uint8_t* bytes = reinterpret_cast(&swapped); - Data->insert(Data->end(), bytes, bytes + sizeof(value)); + Write(&swapped, sizeof(T)); return *this; } @@ -68,4 +72,10 @@ public: Write(static_cast(data.GetStream().GetData()), data.GetStream().GetLength()); return *this; } + +public: + PacketHeader Header{}; + std::vector Data; + size_t BytesTransferred = 0; + size_t BytesRead = 0; }; diff --git a/src/openrct2/network/NetworkPlayer.h b/src/openrct2/network/NetworkPlayer.h index a1cff4d91f..1663dd618a 100644 --- a/src/openrct2/network/NetworkPlayer.h +++ b/src/openrct2/network/NetworkPlayer.h @@ -17,7 +17,7 @@ #include #include -class NetworkPacket; +struct NetworkPacket; class NetworkPlayer final {