From 224efdbbdf8680de3823e733a7e303f61234b821 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 2 Aug 2020 22:07:47 +0200 Subject: [PATCH 1/7] Simplify usage of NetworkPacket::Data --- src/openrct2/network/NetworkConnection.cpp | 18 ++++++------ src/openrct2/network/NetworkPacket.cpp | 32 ++++++++++------------ src/openrct2/network/NetworkPacket.h | 12 ++++---- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index 270e68098b..74b09886cc 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -51,13 +51,13 @@ int32_t NetworkConnection::ReadPacket() { return NETWORK_READPACKET_DISCONNECTED; } - InboundPacket.Data->resize(InboundPacket.Size); + InboundPacket.Data.resize(InboundPacket.Size); } } else { // read packet data - if (InboundPacket.Data->capacity() > 0) + if (InboundPacket.Data.capacity() > 0) { void* buffer = &InboundPacket.GetData()[InboundPacket.BytesTransferred - sizeof(InboundPacket.Size)]; size_t bufferLength = sizeof(InboundPacket.Size) + InboundPacket.Size - InboundPacket.BytesTransferred; @@ -88,7 +88,7 @@ bool NetworkConnection::SendPacket(NetworkPacket& packet) std::vector tosend; tosend.reserve(sizeof(sizen) + packet.Size); tosend.insert(tosend.end(), reinterpret_cast(&sizen), reinterpret_cast(&sizen) + sizeof(sizen)); - tosend.insert(tosend.end(), packet.Data->begin(), packet.Data->end()); + tosend.insert(tosend.end(), packet.Data.begin(), packet.Data.end()); const void* buffer = &tosend[packet.BytesTransferred]; size_t bufferSize = tosend.size() - packet.BytesTransferred; @@ -106,15 +106,15 @@ bool NetworkConnection::SendPacket(NetworkPacket& packet) return sendComplete; } -void NetworkConnection::QueuePacket(std::unique_ptr packet, bool front) +void NetworkConnection::QueuePacket(NetworkPacket&& packet, bool front) { - if (AuthStatus == NETWORK_AUTH_OK || !packet->CommandRequiresAuth()) + if (AuthStatus == NETWORK_AUTH_OK || !packet.CommandRequiresAuth()) { - packet->Size = static_cast(packet->Data->size()); + packet.Size = static_cast(packet.Data.size()); if (front) { // If the first packet was already partially sent add new packet to second position - if (!_outboundPackets.empty() && _outboundPackets.front()->BytesTransferred > 0) + if (!_outboundPackets.empty() && _outboundPackets.front().BytesTransferred > 0) { auto it = _outboundPackets.begin(); it++; // Second position @@ -134,9 +134,9 @@ void NetworkConnection::QueuePacket(std::unique_ptr packet, bool void NetworkConnection::SendQueuedPackets() { - while (!_outboundPackets.empty() && SendPacket(*_outboundPackets.front())) + while (!_outboundPackets.empty() && SendPacket(_outboundPackets.front())) { - _outboundPackets.remove(_outboundPackets.front()); + _outboundPackets.pop_front(); } } diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index e1e4d860c6..59a8d0865c 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -15,35 +15,32 @@ # include -std::unique_ptr NetworkPacket::Allocate() -{ - return std::make_unique(); -} - -std::unique_ptr NetworkPacket::Duplicate(NetworkPacket& packet) -{ - return std::make_unique(packet); -} - uint8_t* NetworkPacket::GetData() { - return &(*Data)[0]; + return Data.data(); +} + +const uint8_t* NetworkPacket::GetData() const +{ + return Data.data(); } NetworkCommand NetworkPacket::GetCommand() const { - if (Data->size() < sizeof(uint32_t)) + if (Data.size() < sizeof(uint32_t)) return NetworkCommand::Invalid; - const uint32_t commandId = ByteSwapBE(*reinterpret_cast(&(*Data)[0])); - return static_cast(commandId); + uint32_t commandId = 0; + std::memcpy(&commandId, GetData(), sizeof(commandId)); + + return static_cast(ByteSwapBE(commandId)); } void NetworkPacket::Clear() { BytesTransferred = 0; BytesRead = 0; - Data->clear(); + Data.clear(); } bool NetworkPacket::CommandRequiresAuth() @@ -63,9 +60,10 @@ bool NetworkPacket::CommandRequiresAuth() } } -void NetworkPacket::Write(const uint8_t* bytes, size_t size) +void NetworkPacket::Write(const void* bytes, size_t size) { - Data->insert(Data->end(), bytes, bytes + size); + const uint8_t* src = reinterpret_cast(bytes); + Data.insert(Data.end(), src, src + size); } void NetworkPacket::WriteString(const utf8* string) diff --git a/src/openrct2/network/NetworkPacket.h b/src/openrct2/network/NetworkPacket.h index 507082e81e..6f8bd3fd9b 100644 --- a/src/openrct2/network/NetworkPacket.h +++ b/src/openrct2/network/NetworkPacket.h @@ -20,14 +20,13 @@ class NetworkPacket final { public: uint16_t Size = 0; - std::shared_ptr> Data = std::make_shared>(); + std::vector Data; size_t BytesTransferred = 0; size_t BytesRead = 0; - static std::unique_ptr Allocate(); - static std::unique_ptr Duplicate(NetworkPacket& packet); - uint8_t* GetData(); + const uint8_t* GetData() const; + NetworkCommand GetCommand() const; void Clear(); @@ -36,7 +35,7 @@ public: const uint8_t* Read(size_t size); const utf8* ReadString(); - void Write(const uint8_t* bytes, size_t size); + void Write(const void* bytes, size_t size); void WriteString(const utf8* string); template NetworkPacket& operator>>(T& value) @@ -58,8 +57,7 @@ public: template NetworkPacket& operator<<(T value) { T swapped = ByteSwapBE(value); - uint8_t* bytes = reinterpret_cast(&swapped); - Data->insert(Data->end(), bytes, bytes + sizeof(value)); + Write(&swapped, sizeof(T)); return *this; } From 313b8a751831d6085d25816f99f28210a6dc0703 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 2 Aug 2020 22:08:56 +0200 Subject: [PATCH 2/7] Remove std::unique_ptr and use move semantics instead --- src/openrct2/network/NetworkBase.cpp | 216 ++++++++++++----------- src/openrct2/network/NetworkBase.h | 2 +- src/openrct2/network/NetworkConnection.h | 12 +- 3 files changed, 119 insertions(+), 111 deletions(-) diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index 4010f7a6ad..b807a8e7a5 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -721,7 +721,7 @@ const char* NetworkBase::FormatChat(NetworkPlayer* fromplayer, const char* text) return formatted; } -void NetworkBase::SendPacketToClients(NetworkPacket& packet, bool front, bool gameCmd) +void NetworkBase::SendPacketToClients(const NetworkPacket& packet, bool front, bool gameCmd) { for (auto& client_connection : client_connection_list) { @@ -742,7 +742,8 @@ void NetworkBase::SendPacketToClients(NetworkPacket& packet, bool front, bool ga continue; } } - client_connection->QueuePacket(NetworkPacket::Duplicate(packet), front); + auto packetCopy = packet; + client_connection->QueuePacket(std::move(packetCopy), front); } } @@ -1207,16 +1208,17 @@ void NetworkBase::Client_Send_RequestGameState(uint32_t tick) } log_verbose("Requesting gamestate from server for tick %u", tick); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::RequestGameState) << tick; + + NetworkPacket packet; + packet << static_cast(NetworkCommand::RequestGameState) << tick; _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Client_Send_TOKEN() { log_verbose("requesting token"); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Token); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Token); _serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED; _serverConnection->QueuePacket(std::move(packet)); } @@ -1224,15 +1226,15 @@ void NetworkBase::Client_Send_TOKEN() void NetworkBase::Client_Send_AUTH( const std::string& name, const std::string& password, const std::string& pubkey, const std::vector& signature) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Auth); - packet->WriteString(network_get_version().c_str()); - packet->WriteString(name.c_str()); - packet->WriteString(password.c_str()); - packet->WriteString(pubkey.c_str()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Auth); + packet.WriteString(network_get_version().c_str()); + packet.WriteString(name.c_str()); + packet.WriteString(password.c_str()); + packet.WriteString(pubkey.c_str()); assert(signature.size() <= static_cast(UINT32_MAX)); - *packet << static_cast(signature.size()); - packet->Write(signature.data(), signature.size()); + packet << static_cast(signature.size()); + packet.Write(signature.data(), signature.size()); _serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED; _serverConnection->QueuePacket(std::move(packet)); } @@ -1240,21 +1242,21 @@ void NetworkBase::Client_Send_AUTH( void NetworkBase::Client_Send_MAPREQUEST(const std::vector& objects) { log_verbose("client requests %u objects", uint32_t(objects.size())); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::MapRequest) << static_cast(objects.size()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::MapRequest) << static_cast(objects.size()); for (const auto& object : objects) { log_verbose("client requests object %s", object.c_str()); - packet->Write(reinterpret_cast(object.c_str()), 8); + packet.Write(reinterpret_cast(object.c_str()), 8); } _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_TOKEN(NetworkConnection& connection) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Token) << static_cast(connection.Challenge.size()); - packet->Write(connection.Challenge.data(), connection.Challenge.size()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Token) << static_cast(connection.Challenge.size()); + packet.Write(connection.Challenge.data(), connection.Challenge.size()); connection.QueuePacket(std::move(packet)); } @@ -1265,9 +1267,9 @@ void NetworkBase::Server_Send_OBJECTS_LIST( if (objects.empty()) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::ObjectsList) << static_cast(0) - << static_cast(objects.size()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::ObjectsList) << static_cast(0) + << static_cast(objects.size()); connection.QueuePacket(std::move(packet)); } @@ -1277,13 +1279,13 @@ void NetworkBase::Server_Send_OBJECTS_LIST( { const auto* object = objects[i]; - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::ObjectsList) << static_cast(i) - << static_cast(objects.size()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::ObjectsList) << static_cast(i) + << static_cast(objects.size()); log_verbose("Object %.8s (checksum %x)", object->ObjectEntry.name, object->ObjectEntry.checksum); - packet->Write(reinterpret_cast(object->ObjectEntry.name), 8); - *packet << object->ObjectEntry.checksum << object->ObjectEntry.flags; + packet.Write(reinterpret_cast(object->ObjectEntry.name), 8); + packet << object->ObjectEntry.checksum << object->ObjectEntry.flags; connection.QueuePacket(std::move(packet)); } @@ -1292,8 +1294,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST( void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Scripts); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Scripts); # ifdef ENABLE_SCRIPTING using namespace OpenRCT2::Scripting; @@ -1310,18 +1312,18 @@ void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const } log_verbose("Server sends %u scripts", pluginsToSend.size()); - *packet << static_cast(pluginsToSend.size()); + packet << static_cast(pluginsToSend.size()); for (const auto& plugin : pluginsToSend) { const auto& metadata = plugin->GetMetadata(); log_verbose("Script %s", metadata.Name.c_str()); const auto& code = plugin->GetCode(); - *packet << static_cast(code.size()); - packet->Write(reinterpret_cast(code.c_str()), code.size()); + packet << static_cast(code.size()); + packet.Write(reinterpret_cast(code.c_str()), code.size()); } # else - *packet << static_cast(0); + packet << static_cast(0); # endif connection.QueuePacket(std::move(packet)); } @@ -1330,8 +1332,8 @@ void NetworkBase::Client_Send_HEARTBEAT(NetworkConnection& connection) const { log_verbose("Sending heartbeat"); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Heartbeat); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Heartbeat); connection.QueuePacket(std::move(packet)); } @@ -1364,11 +1366,11 @@ void NetworkBase::Server_Send_AUTH(NetworkConnection& connection) { new_playerid = connection.Player->Id; } - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Auth) << static_cast(connection.AuthStatus) << new_playerid; + NetworkPacket packet; + packet << static_cast(NetworkCommand::Auth) << static_cast(connection.AuthStatus) << new_playerid; if (connection.AuthStatus == NETWORK_AUTH_BADVERSION) { - packet->WriteString(network_get_version().c_str()); + packet.WriteString(network_get_version().c_str()); } connection.QueuePacket(std::move(packet)); if (connection.AuthStatus != NETWORK_AUTH_OK && connection.AuthStatus != NETWORK_AUTH_REQUIREPASSWORD) @@ -1408,16 +1410,16 @@ void NetworkBase::Server_Send_MAP(NetworkConnection* connection) for (size_t i = 0; i < out_size; i += chunksize) { size_t datasize = std::min(chunksize, out_size - i); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Map) << static_cast(out_size) << static_cast(i); - packet->Write(&header[i], datasize); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Map) << static_cast(out_size) << static_cast(i); + packet.Write(&header[i], datasize); if (connection) { connection->QueuePacket(std::move(packet)); } else { - SendPacketToClients(*packet); + SendPacketToClients(packet); } } free(header); @@ -1478,22 +1480,22 @@ uint8_t* NetworkBase::save_for_network(size_t& out_size, const std::vector packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Chat); - packet->WriteString(text); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Chat); + packet.WriteString(text); _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& playerIds) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Chat); - packet->WriteString(text); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Chat); + packet.WriteString(text); if (playerIds.empty()) { // Empty players / default value means send to all players - SendPacketToClients(*packet); + SendPacketToClients(packet); } else { @@ -1502,7 +1504,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& auto conn = GetPlayerConnection(playerId); if (conn != nullptr && !conn->IsDisconnected) { - conn->QueuePacket(NetworkPacket::Duplicate(*packet)); + conn->QueuePacket(packet); } } } @@ -1510,7 +1512,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action) { - std::unique_ptr packet(NetworkPacket::Allocate()); + NetworkPacket packet; uint32_t networkId = 0; networkId = ++_actionId; @@ -1525,26 +1527,26 @@ void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action) DataSerialiser stream(true); action->Serialise(stream); - *packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; + packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_GAME_ACTION(const GameAction* action) { - std::unique_ptr packet(NetworkPacket::Allocate()); + NetworkPacket packet; DataSerialiser stream(true); action->Serialise(stream); - *packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; + packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_TICK() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Tick) << gCurrentTicks << scenario_rand_state().s0; + NetworkPacket packet; + packet << static_cast(NetworkCommand::Tick) << gCurrentTicks << scenario_rand_state().s0; uint32_t flags = 0; // Simple counter which limits how often a sprite checksum gets sent. // This can get somewhat expensive, so we don't want to push it every tick in release, @@ -1558,75 +1560,75 @@ void NetworkBase::Server_Send_TICK() } // Send flags always, so we can understand packet structure on the other end, // and allow for some expansion. - *packet << flags; + packet << flags; if (flags & NETWORK_TICK_FLAG_CHECKSUMS) { rct_sprite_checksum checksum = sprite_checksum(); - packet->WriteString(checksum.ToString().c_str()); + packet.WriteString(checksum.ToString().c_str()); } - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_PLAYERINFO(int32_t playerId) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::PlayerInfo) << gCurrentTicks; + NetworkPacket packet; + packet << static_cast(NetworkCommand::PlayerInfo) << gCurrentTicks; auto* player = GetPlayerByID(playerId); if (player == nullptr) return; - player->Write(*packet); - SendPacketToClients(*packet); + player->Write(packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_PLAYERLIST() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::PlayerList) << gCurrentTicks << static_cast(player_list.size()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::PlayerList) << gCurrentTicks << static_cast(player_list.size()); for (auto& player : player_list) { - player->Write(*packet); + player->Write(packet); } - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Client_Send_PING() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Ping); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Ping); _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_PING() { last_ping_sent_time = platform_get_ticks(); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Ping); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Ping); for (auto& client_connection : client_connection_list) { client_connection->PingTime = platform_get_ticks(); } - SendPacketToClients(*packet, true); + SendPacketToClients(packet, true); } void NetworkBase::Server_Send_PINGLIST() { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::PingList) << static_cast(player_list.size()); + NetworkPacket packet; + packet << static_cast(NetworkCommand::PingList) << static_cast(player_list.size()); for (auto& player : player_list) { - *packet << player->Id << player->Ping; + packet << player->Id << player->Ping; } - SendPacketToClients(*packet); + SendPacketToClients(packet); } void NetworkBase::Server_Send_SETDISCONNECTMSG(NetworkConnection& connection, const char* msg) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::DisconnectMessage); - packet->WriteString(msg); + NetworkPacket packet; + packet << static_cast(NetworkCommand::DisconnectMessage); + packet.WriteString(msg); connection.QueuePacket(std::move(packet)); } @@ -1646,8 +1648,8 @@ json_t* NetworkBase::GetServerInfoAsJson() const void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::GameInfo); + NetworkPacket packet; + packet << static_cast(NetworkCommand::GameInfo); # ifndef DISABLE_HTTP json_t* obj = GetServerInfoAsJson(); @@ -1658,8 +1660,8 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) json_object_set_new(jsonProvider, "website", json_string(gConfigNetwork.provider_website.c_str())); json_object_set_new(obj, "provider", jsonProvider); - packet->WriteString(json_dumps(obj, 0)); - *packet << _serverState.gamestateSnapshotsEnabled; + packet.WriteString(json_dumps(obj, 0)); + packet << _serverState.gamestateSnapshotsEnabled; json_decref(obj); # endif @@ -1668,39 +1670,39 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) void NetworkBase::Server_Send_SHOWERROR(NetworkConnection& connection, rct_string_id title, rct_string_id message) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::ShowError) << title << message; + NetworkPacket packet; + packet << static_cast(NetworkCommand::ShowError) << title << message; connection.QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_GROUPLIST(NetworkConnection& connection) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::GroupList) << static_cast(group_list.size()) << default_group; + NetworkPacket packet; + packet << static_cast(NetworkCommand::GroupList) << static_cast(group_list.size()) << default_group; for (auto& group : group_list) { - group->Write(*packet); + group->Write(packet); } connection.QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_EVENT_PLAYER_JOINED(const char* playerName) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Event); - *packet << static_cast(SERVER_EVENT_PLAYER_JOINED); - packet->WriteString(playerName); - SendPacketToClients(*packet); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Event); + packet << static_cast(SERVER_EVENT_PLAYER_JOINED); + packet.WriteString(playerName); + SendPacketToClients(packet); } void NetworkBase::Server_Send_EVENT_PLAYER_DISCONNECTED(const char* playerName, const char* reason) { - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::Event); - *packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); - packet->WriteString(playerName); - packet->WriteString(reason); - SendPacketToClients(*packet); + NetworkPacket packet; + packet << static_cast(NetworkCommand::Event); + packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); + packet.WriteString(playerName); + packet.WriteString(reason); + SendPacketToClients(packet); } bool NetworkBase::ProcessConnection(NetworkConnection& connection) @@ -2240,11 +2242,11 @@ void NetworkBase::Server_Handle_REQUEST_GAMESTATE(NetworkConnection& connection, dataSize = snapshotMemory.GetLength() - bytesSent; } - std::unique_ptr gameStateChunk(NetworkPacket::Allocate()); - *gameStateChunk << static_cast(NetworkCommand::GameState) << tick << length << bytesSent << dataSize; - gameStateChunk->Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); + NetworkPacket packetGameStateChunk; + packetGameStateChunk << static_cast(NetworkCommand::GameState) << tick << length << bytesSent << dataSize; + packetGameStateChunk.Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); - connection.QueuePacket(std::move(gameStateChunk)); + connection.QueuePacket(std::move(packetGameStateChunk)); bytesSent += dataSize; } @@ -3208,8 +3210,8 @@ void NetworkBase::Client_Handle_EVENT([[maybe_unused]] NetworkConnection& connec void NetworkBase::Client_Send_GAMEINFO() { log_verbose("requesting gameinfo"); - std::unique_ptr packet(NetworkPacket::Allocate()); - *packet << static_cast(NetworkCommand::GameInfo); + NetworkPacket packet; + packet << static_cast(NetworkCommand::GameInfo); _serverConnection->QueuePacket(std::move(packet)); } diff --git a/src/openrct2/network/NetworkBase.h b/src/openrct2/network/NetworkBase.h index ad8eed136e..0f2f1fd0b5 100644 --- a/src/openrct2/network/NetworkBase.h +++ b/src/openrct2/network/NetworkBase.h @@ -113,7 +113,7 @@ public: // Client void ProcessPlayerInfo(); void ProcessDisconnectedClients(); static const char* FormatChat(NetworkPlayer* fromplayer, const char* text); - void SendPacketToClients(NetworkPacket& packet, bool front = false, bool gameCmd = false); + void SendPacketToClients(const NetworkPacket& packet, bool front = false, bool gameCmd = false); bool CheckSRAND(uint32_t tick, uint32_t srand0); bool CheckDesynchronizaton(); void RequestStateSnapshot(); diff --git a/src/openrct2/network/NetworkConnection.h b/src/openrct2/network/NetworkConnection.h index 0dcd379e73..bdfb8697cd 100644 --- a/src/openrct2/network/NetworkConnection.h +++ b/src/openrct2/network/NetworkConnection.h @@ -16,7 +16,7 @@ # include "NetworkTypes.h" # include "Socket.h" -# include +# include # include # include @@ -41,7 +41,13 @@ public: ~NetworkConnection(); int32_t ReadPacket(); - void QueuePacket(std::unique_ptr packet, bool front = false); + void QueuePacket(NetworkPacket&& packet, bool front = false); + void QueuePacket(const NetworkPacket& packet, bool front = false) + { + auto copy = packet; + return QueuePacket(std::move(copy)); + } + void SendQueuedPackets(); void ResetLastPacketTime(); bool ReceivedPacketRecently(); @@ -51,7 +57,7 @@ public: void SetLastDisconnectReason(const rct_string_id string_id, void* args = nullptr); private: - std::list> _outboundPackets; + std::deque _outboundPackets; uint32_t _lastPacketTime = 0; utf8* _lastDisconnectReason = nullptr; From 21f63c1010db6d8680dcf451924348b288caf780 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 2 Aug 2020 22:31:00 +0200 Subject: [PATCH 3/7] Pass NetworkCommand type as constructor to simplify code --- src/openrct2/network/NetworkBase.cpp | 107 +++++++++++-------------- src/openrct2/network/NetworkPacket.cpp | 5 ++ src/openrct2/network/NetworkPacket.h | 6 +- src/openrct2/network/NetworkPlayer.h | 2 +- 4 files changed, 56 insertions(+), 64 deletions(-) diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index b807a8e7a5..2ce80a79a6 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -1209,16 +1209,15 @@ void NetworkBase::Client_Send_RequestGameState(uint32_t tick) log_verbose("Requesting gamestate from server for tick %u", tick); - NetworkPacket packet; - packet << static_cast(NetworkCommand::RequestGameState) << tick; + NetworkPacket packet(NetworkCommand::RequestGameState); + packet << tick; _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Client_Send_TOKEN() { log_verbose("requesting token"); - NetworkPacket packet; - packet << static_cast(NetworkCommand::Token); + NetworkPacket packet(NetworkCommand::Token); _serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED; _serverConnection->QueuePacket(std::move(packet)); } @@ -1226,8 +1225,7 @@ void NetworkBase::Client_Send_TOKEN() void NetworkBase::Client_Send_AUTH( const std::string& name, const std::string& password, const std::string& pubkey, const std::vector& signature) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Auth); + NetworkPacket packet(NetworkCommand::Auth); packet.WriteString(network_get_version().c_str()); packet.WriteString(name.c_str()); packet.WriteString(password.c_str()); @@ -1242,8 +1240,8 @@ void NetworkBase::Client_Send_AUTH( void NetworkBase::Client_Send_MAPREQUEST(const std::vector& objects) { log_verbose("client requests %u objects", uint32_t(objects.size())); - NetworkPacket packet; - packet << static_cast(NetworkCommand::MapRequest) << static_cast(objects.size()); + NetworkPacket packet(NetworkCommand::MapRequest); + packet << static_cast(objects.size()); for (const auto& object : objects) { log_verbose("client requests object %s", object.c_str()); @@ -1254,8 +1252,8 @@ void NetworkBase::Client_Send_MAPREQUEST(const std::vector& objects void NetworkBase::Server_Send_TOKEN(NetworkConnection& connection) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Token) << static_cast(connection.Challenge.size()); + NetworkPacket packet(NetworkCommand::Token); + packet << static_cast(connection.Challenge.size()); packet.Write(connection.Challenge.data(), connection.Challenge.size()); connection.QueuePacket(std::move(packet)); } @@ -1267,9 +1265,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST( if (objects.empty()) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::ObjectsList) << static_cast(0) - << static_cast(objects.size()); + NetworkPacket packet(NetworkCommand::ObjectsList); + packet << static_cast(0) << static_cast(objects.size()); connection.QueuePacket(std::move(packet)); } @@ -1279,9 +1276,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST( { const auto* object = objects[i]; - NetworkPacket packet; - packet << static_cast(NetworkCommand::ObjectsList) << static_cast(i) - << static_cast(objects.size()); + NetworkPacket packet(NetworkCommand::ObjectsList); + packet << static_cast(i) << static_cast(objects.size()); log_verbose("Object %.8s (checksum %x)", object->ObjectEntry.name, object->ObjectEntry.checksum); packet.Write(reinterpret_cast(object->ObjectEntry.name), 8); @@ -1294,8 +1290,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST( void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Scripts); + NetworkPacket packet(NetworkCommand::Scripts); + # ifdef ENABLE_SCRIPTING using namespace OpenRCT2::Scripting; @@ -1332,9 +1328,7 @@ void NetworkBase::Client_Send_HEARTBEAT(NetworkConnection& connection) const { log_verbose("Sending heartbeat"); - NetworkPacket packet; - packet << static_cast(NetworkCommand::Heartbeat); - + NetworkPacket packet(NetworkCommand::Heartbeat); connection.QueuePacket(std::move(packet)); } @@ -1366,8 +1360,8 @@ void NetworkBase::Server_Send_AUTH(NetworkConnection& connection) { new_playerid = connection.Player->Id; } - NetworkPacket packet; - packet << static_cast(NetworkCommand::Auth) << static_cast(connection.AuthStatus) << new_playerid; + NetworkPacket packet(NetworkCommand::Auth); + packet << static_cast(connection.AuthStatus) << new_playerid; if (connection.AuthStatus == NETWORK_AUTH_BADVERSION) { packet.WriteString(network_get_version().c_str()); @@ -1410,8 +1404,8 @@ void NetworkBase::Server_Send_MAP(NetworkConnection* connection) for (size_t i = 0; i < out_size; i += chunksize) { size_t datasize = std::min(chunksize, out_size - i); - NetworkPacket packet; - packet << static_cast(NetworkCommand::Map) << static_cast(out_size) << static_cast(i); + NetworkPacket packet(NetworkCommand::Map); + packet << static_cast(out_size) << static_cast(i); packet.Write(&header[i], datasize); if (connection) { @@ -1480,16 +1474,14 @@ uint8_t* NetworkBase::save_for_network(size_t& out_size, const std::vector(NetworkCommand::Chat); + NetworkPacket packet(NetworkCommand::Chat); packet.WriteString(text); _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& playerIds) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Chat); + NetworkPacket packet(NetworkCommand::Chat); packet.WriteString(text); if (playerIds.empty()) @@ -1512,7 +1504,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector& void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action) { - NetworkPacket packet; + NetworkPacket packet(NetworkCommand::GameAction); uint32_t networkId = 0; networkId = ++_actionId; @@ -1527,26 +1519,26 @@ void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action) DataSerialiser stream(true); action->Serialise(stream); - packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; + packet << gCurrentTicks << action->GetType() << stream; _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_GAME_ACTION(const GameAction* action) { - NetworkPacket packet; + NetworkPacket packet(NetworkCommand::GameAction); DataSerialiser stream(true); action->Serialise(stream); - packet << static_cast(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream; + packet << gCurrentTicks << action->GetType() << stream; SendPacketToClients(packet); } void NetworkBase::Server_Send_TICK() { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Tick) << gCurrentTicks << scenario_rand_state().s0; + NetworkPacket packet(NetworkCommand::Tick); + packet << gCurrentTicks << scenario_rand_state().s0; uint32_t flags = 0; // Simple counter which limits how often a sprite checksum gets sent. // This can get somewhat expensive, so we don't want to push it every tick in release, @@ -1572,8 +1564,8 @@ void NetworkBase::Server_Send_TICK() void NetworkBase::Server_Send_PLAYERINFO(int32_t playerId) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::PlayerInfo) << gCurrentTicks; + NetworkPacket packet(NetworkCommand::PlayerInfo); + packet << gCurrentTicks; auto* player = GetPlayerByID(playerId); if (player == nullptr) @@ -1585,8 +1577,8 @@ void NetworkBase::Server_Send_PLAYERINFO(int32_t playerId) void NetworkBase::Server_Send_PLAYERLIST() { - NetworkPacket packet; - packet << static_cast(NetworkCommand::PlayerList) << gCurrentTicks << static_cast(player_list.size()); + NetworkPacket packet(NetworkCommand::PlayerList); + packet << gCurrentTicks << static_cast(player_list.size()); for (auto& player : player_list) { player->Write(packet); @@ -1596,16 +1588,14 @@ void NetworkBase::Server_Send_PLAYERLIST() void NetworkBase::Client_Send_PING() { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Ping); + NetworkPacket packet(NetworkCommand::Ping); _serverConnection->QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_PING() { last_ping_sent_time = platform_get_ticks(); - NetworkPacket packet; - packet << static_cast(NetworkCommand::Ping); + NetworkPacket packet(NetworkCommand::Ping); for (auto& client_connection : client_connection_list) { client_connection->PingTime = platform_get_ticks(); @@ -1615,8 +1605,8 @@ void NetworkBase::Server_Send_PING() void NetworkBase::Server_Send_PINGLIST() { - NetworkPacket packet; - packet << static_cast(NetworkCommand::PingList) << static_cast(player_list.size()); + NetworkPacket packet(NetworkCommand::PingList); + packet << static_cast(player_list.size()); for (auto& player : player_list) { packet << player->Id << player->Ping; @@ -1626,8 +1616,7 @@ void NetworkBase::Server_Send_PINGLIST() void NetworkBase::Server_Send_SETDISCONNECTMSG(NetworkConnection& connection, const char* msg) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::DisconnectMessage); + NetworkPacket packet(NetworkCommand::DisconnectMessage); packet.WriteString(msg); connection.QueuePacket(std::move(packet)); } @@ -1648,8 +1637,7 @@ json_t* NetworkBase::GetServerInfoAsJson() const void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::GameInfo); + NetworkPacket packet(NetworkCommand::GameInfo); # ifndef DISABLE_HTTP json_t* obj = GetServerInfoAsJson(); @@ -1670,15 +1658,15 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection) void NetworkBase::Server_Send_SHOWERROR(NetworkConnection& connection, rct_string_id title, rct_string_id message) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::ShowError) << title << message; + NetworkPacket packet(NetworkCommand::ShowError); + packet << title << message; connection.QueuePacket(std::move(packet)); } void NetworkBase::Server_Send_GROUPLIST(NetworkConnection& connection) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::GroupList) << static_cast(group_list.size()) << default_group; + NetworkPacket packet(NetworkCommand::GroupList); + packet << static_cast(group_list.size()) << default_group; for (auto& group : group_list) { group->Write(packet); @@ -1688,8 +1676,7 @@ void NetworkBase::Server_Send_GROUPLIST(NetworkConnection& connection) void NetworkBase::Server_Send_EVENT_PLAYER_JOINED(const char* playerName) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Event); + NetworkPacket packet(NetworkCommand::Event); packet << static_cast(SERVER_EVENT_PLAYER_JOINED); packet.WriteString(playerName); SendPacketToClients(packet); @@ -1697,8 +1684,7 @@ void NetworkBase::Server_Send_EVENT_PLAYER_JOINED(const char* playerName) void NetworkBase::Server_Send_EVENT_PLAYER_DISCONNECTED(const char* playerName, const char* reason) { - NetworkPacket packet; - packet << static_cast(NetworkCommand::Event); + NetworkPacket packet(NetworkCommand::Event); packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); packet.WriteString(playerName); packet.WriteString(reason); @@ -2242,8 +2228,8 @@ void NetworkBase::Server_Handle_REQUEST_GAMESTATE(NetworkConnection& connection, dataSize = snapshotMemory.GetLength() - bytesSent; } - NetworkPacket packetGameStateChunk; - packetGameStateChunk << static_cast(NetworkCommand::GameState) << tick << length << bytesSent << dataSize; + NetworkPacket packetGameStateChunk(NetworkCommand::GameState); + packetGameStateChunk << tick << length << bytesSent << dataSize; packetGameStateChunk.Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); connection.QueuePacket(std::move(packetGameStateChunk)); @@ -3210,8 +3196,7 @@ void NetworkBase::Client_Handle_EVENT([[maybe_unused]] NetworkConnection& connec void NetworkBase::Client_Send_GAMEINFO() { log_verbose("requesting gameinfo"); - NetworkPacket packet; - packet << static_cast(NetworkCommand::GameInfo); + NetworkPacket packet(NetworkCommand::GameInfo); _serverConnection->QueuePacket(std::move(packet)); } diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index 59a8d0865c..bf285329e0 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -15,6 +15,11 @@ # include +NetworkPacket::NetworkPacket(NetworkCommand id) +{ + *this << static_cast::type>(id); +} + uint8_t* NetworkPacket::GetData() { return Data.data(); diff --git a/src/openrct2/network/NetworkPacket.h b/src/openrct2/network/NetworkPacket.h index 6f8bd3fd9b..bd376f993d 100644 --- a/src/openrct2/network/NetworkPacket.h +++ b/src/openrct2/network/NetworkPacket.h @@ -16,14 +16,16 @@ #include #include -class NetworkPacket final +struct NetworkPacket final { -public: uint16_t Size = 0; std::vector Data; size_t BytesTransferred = 0; size_t BytesRead = 0; + NetworkPacket() = default; + NetworkPacket(NetworkCommand id); + uint8_t* GetData(); const uint8_t* GetData() const; diff --git a/src/openrct2/network/NetworkPlayer.h b/src/openrct2/network/NetworkPlayer.h index a1cff4d91f..1663dd618a 100644 --- a/src/openrct2/network/NetworkPlayer.h +++ b/src/openrct2/network/NetworkPlayer.h @@ -17,7 +17,7 @@ #include #include -class NetworkPacket; +struct NetworkPacket; class NetworkPlayer final { From 28dcaeae2fce1945d09271bd5414cfff9162d0c4 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 2 Aug 2020 23:34:02 +0200 Subject: [PATCH 4/7] Refactor reading/writing packets --- src/openrct2/network/NetworkBase.cpp | 11 +-- src/openrct2/network/NetworkConnection.cpp | 89 ++++++++++++++-------- src/openrct2/network/NetworkPacket.cpp | 14 +--- src/openrct2/network/NetworkPacket.h | 20 +++-- 4 files changed, 79 insertions(+), 55 deletions(-) diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index 2ce80a79a6..f792d203e2 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -1736,11 +1736,8 @@ bool NetworkBase::ProcessConnection(NetworkConnection& connection) void NetworkBase::ProcessPacket(NetworkConnection& connection, NetworkPacket& packet) { - std::underlying_type::type command; - packet >> command; - const auto& handlerList = GetMode() == NETWORK_MODE_SERVER ? server_command_handlers : client_command_handlers; - auto it = handlerList.find(static_cast(command)); + auto it = handlerList.find(packet.GetCommand()); if (it != handlerList.end()) { auto commandHandler = it->second; @@ -2650,7 +2647,7 @@ void NetworkBase::Client_Handle_MAP([[maybe_unused]] NetworkConnection& connecti { uint32_t size, offset; packet >> size >> offset; - int32_t chunksize = static_cast(packet.Size - packet.BytesRead); + int32_t chunksize = static_cast(packet.Header.Size - packet.BytesRead); if (chunksize <= 0) { return; @@ -2925,7 +2922,7 @@ void NetworkBase::Client_Handle_GAME_ACTION([[maybe_unused]] NetworkConnection& packet >> tick >> actionType; MemoryStream stream; - size_t size = packet.Size - packet.BytesRead; + const size_t size = packet.Header.Size - packet.BytesRead; stream.WriteArray(packet.Read(size), size); stream.SetPosition(0); @@ -3015,7 +3012,7 @@ void NetworkBase::Server_Handle_GAME_ACTION(NetworkConnection& connection, Netwo } DataSerialiser stream(false); - size_t size = packet.Size - packet.BytesRead; + const size_t size = packet.Header.Size - packet.BytesRead; stream.GetStream().WriteArray(packet.Read(size), size); stream.GetStream().SetPosition(0); diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index 74b09886cc..1d1a4f1df1 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -18,6 +18,7 @@ # include "network.h" constexpr size_t NETWORK_DISCONNECT_REASON_BUFFER_SIZE = 256; +constexpr size_t NetworkBufferSize = 1024; NetworkConnection::NetworkConnection() { @@ -31,47 +32,61 @@ NetworkConnection::~NetworkConnection() int32_t NetworkConnection::ReadPacket() { - if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Size)) + size_t bytesRead = 0; + + // Read packet header. + auto& header = InboundPacket.Header; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - // read packet size - void* buffer = &(reinterpret_cast(&InboundPacket.Size))[InboundPacket.BytesTransferred]; - size_t bufferLength = sizeof(InboundPacket.Size) - InboundPacket.BytesTransferred; - size_t readBytes; - NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes); + const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred; + + uint8_t* buffer = reinterpret_cast(&InboundPacket.Header); + + NETWORK_READPACKET status = Socket->ReceiveData(buffer, missingLength, &bytesRead); if (status != NETWORK_READPACKET_SUCCESS) { return status; } - InboundPacket.BytesTransferred += readBytes; - if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size)) + InboundPacket.BytesTransferred += bytesRead; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - InboundPacket.Size = Convert::NetworkToHost(InboundPacket.Size); - if (InboundPacket.Size == 0) // Can't have a size 0 packet - { - return NETWORK_READPACKET_DISCONNECTED; - } - InboundPacket.Data.resize(InboundPacket.Size); + // If still not enough data for header, keep waiting. + return NETWORK_READPACKET_MORE_DATA; } + + // Normalise values. + header.Size = Convert::NetworkToHost(header.Size); + header.Id = ByteSwapBE(header.Id); + + // NOTE: For compatibility reasons we need to remove sizeof(Header.Id) from the size + // Previously the Id field was not part of the header rather part of the body + header.Size -= sizeof(header.Id); + + // Fall-through: Read rest of packet. } - else + + // Read packet body. { - // read packet data - if (InboundPacket.Data.capacity() > 0) + const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header)); + + uint8_t buffer[NetworkBufferSize]; + + if (missingLength > 0) { - void* buffer = &InboundPacket.GetData()[InboundPacket.BytesTransferred - sizeof(InboundPacket.Size)]; - size_t bufferLength = sizeof(InboundPacket.Size) + InboundPacket.Size - InboundPacket.BytesTransferred; - size_t readBytes; - NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes); + NETWORK_READPACKET status = Socket->ReceiveData(buffer, std::min(missingLength, NetworkBufferSize), &bytesRead); if (status != NETWORK_READPACKET_SUCCESS) { return status; } - InboundPacket.BytesTransferred += readBytes; + InboundPacket.BytesTransferred += bytesRead; + InboundPacket.Write(buffer, bytesRead); } - if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size) + InboundPacket.Size) + + if (InboundPacket.Data.size() == header.Size) { + // Received complete packet. _lastPacketTime = platform_get_ticks(); RecordPacketStats(InboundPacket, false); @@ -79,26 +94,34 @@ int32_t NetworkConnection::ReadPacket() return NETWORK_READPACKET_SUCCESS; } } + return NETWORK_READPACKET_MORE_DATA; } bool NetworkConnection::SendPacket(NetworkPacket& packet) { - uint16_t sizen = Convert::HostToNetwork(packet.Size); - std::vector tosend; - tosend.reserve(sizeof(sizen) + packet.Size); - tosend.insert(tosend.end(), reinterpret_cast(&sizen), reinterpret_cast(&sizen) + sizeof(sizen)); - tosend.insert(tosend.end(), packet.Data.begin(), packet.Data.end()); + auto header = packet.Header; - const void* buffer = &tosend[packet.BytesTransferred]; - size_t bufferSize = tosend.size() - packet.BytesTransferred; - size_t sent = Socket->SendData(buffer, bufferSize); + std::vector buffer; + buffer.reserve(sizeof(header) + header.Size); + + // NOTE: For compatibility reasons we need to add sizeof(Header.Id) to the size + // Previously the Id field was not part of the header rather part of the body + header.Size += sizeof(header.Id); + header.Size = Convert::HostToNetwork(header.Size); + header.Id = ByteSwapBE(header.Id); + + buffer.insert(buffer.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); + buffer.insert(buffer.end(), packet.Data.begin(), packet.Data.end()); + + size_t bufferSize = buffer.size() - packet.BytesTransferred; + size_t sent = Socket->SendData(buffer.data() + packet.BytesTransferred, bufferSize); if (sent > 0) { packet.BytesTransferred += sent; } - bool sendComplete = packet.BytesTransferred == tosend.size(); + bool sendComplete = packet.BytesTransferred == buffer.size(); if (sendComplete) { RecordPacketStats(packet, true); @@ -110,7 +133,7 @@ void NetworkConnection::QueuePacket(NetworkPacket&& packet, bool front) { if (AuthStatus == NETWORK_AUTH_OK || !packet.CommandRequiresAuth()) { - packet.Size = static_cast(packet.Data.size()); + packet.Header.Size = static_cast(packet.Data.size()); if (front) { // If the first packet was already partially sent add new packet to second position diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index bf285329e0..c781fd1978 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -16,8 +16,8 @@ # include NetworkPacket::NetworkPacket(NetworkCommand id) + : Header{ 0, id } { - *this << static_cast::type>(id); } uint8_t* NetworkPacket::GetData() @@ -32,13 +32,7 @@ const uint8_t* NetworkPacket::GetData() const NetworkCommand NetworkPacket::GetCommand() const { - if (Data.size() < sizeof(uint32_t)) - return NetworkCommand::Invalid; - - uint32_t commandId = 0; - std::memcpy(&commandId, GetData(), sizeof(commandId)); - - return static_cast(ByteSwapBE(commandId)); + return Header.Id; } void NetworkPacket::Clear() @@ -78,7 +72,7 @@ void NetworkPacket::WriteString(const utf8* string) const uint8_t* NetworkPacket::Read(size_t size) { - if (BytesRead + size > NetworkPacket::Size) + if (BytesRead + size > Header.Size) { return nullptr; } @@ -94,7 +88,7 @@ const utf8* NetworkPacket::ReadString() { char* str = reinterpret_cast(&GetData()[BytesRead]); char* strend = str; - while (BytesRead < Size && *strend != 0) + while (BytesRead < Header.Size && *strend != 0) { BytesRead++; strend++; diff --git a/src/openrct2/network/NetworkPacket.h b/src/openrct2/network/NetworkPacket.h index bd376f993d..ffa464333a 100644 --- a/src/openrct2/network/NetworkPacket.h +++ b/src/openrct2/network/NetworkPacket.h @@ -16,13 +16,17 @@ #include #include -struct NetworkPacket final +#pragma pack(push, 1) +struct PacketHeader { uint16_t Size = 0; - std::vector Data; - size_t BytesTransferred = 0; - size_t BytesRead = 0; + NetworkCommand Id = NetworkCommand::Invalid; +}; +static_assert(sizeof(PacketHeader) == 6); +#pragma pack(pop) +struct NetworkPacket final +{ NetworkPacket() = default; NetworkPacket(NetworkCommand id); @@ -42,7 +46,7 @@ struct NetworkPacket final template NetworkPacket& operator>>(T& value) { - if (BytesRead + sizeof(value) > Size) + if (BytesRead + sizeof(value) > Header.Size) { value = T{}; } @@ -68,4 +72,10 @@ struct NetworkPacket final Write(static_cast(data.GetStream().GetData()), data.GetStream().GetLength()); return *this; } + +public: + PacketHeader Header{}; + std::vector Data; + size_t BytesTransferred = 0; + size_t BytesRead = 0; }; From 07b343813ab955648bdd760b7dc80e99bff7c900 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 3 Aug 2020 00:46:44 +0200 Subject: [PATCH 5/7] Fix passing missing parameter --- src/openrct2/network/NetworkConnection.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/openrct2/network/NetworkConnection.h b/src/openrct2/network/NetworkConnection.h index bdfb8697cd..77000bf04e 100644 --- a/src/openrct2/network/NetworkConnection.h +++ b/src/openrct2/network/NetworkConnection.h @@ -45,7 +45,7 @@ public: void QueuePacket(const NetworkPacket& packet, bool front = false) { auto copy = packet; - return QueuePacket(std::move(copy)); + return QueuePacket(std::move(copy), front); } void SendQueuedPackets(); From 05a9b271362a4b2f1de0966a5ae7733b939ff6ad Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 3 Aug 2020 01:19:20 +0200 Subject: [PATCH 6/7] Fix ByteSwapBE type safety --- src/openrct2/core/Endianness.h | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/openrct2/core/Endianness.h b/src/openrct2/core/Endianness.h index 22db6e49e1..5bebd494b6 100644 --- a/src/openrct2/core/Endianness.h +++ b/src/openrct2/core/Endianness.h @@ -11,6 +11,9 @@ #include "../common.h" +#include +#include + template struct ByteSwapT { }; @@ -58,6 +61,22 @@ template<> struct ByteSwapT<8> template static T ByteSwapBE(const T& value) { using ByteSwap = ByteSwapT; - typename ByteSwap::UIntType result = ByteSwap::SwapBE(reinterpret_cast(value)); - return *reinterpret_cast(&result); + using UIntType = typename ByteSwap::UIntType; + + if constexpr (std::is_enum_v || std::is_integral_v) + { + auto result = ByteSwap::SwapBE(static_cast(value)); + return static_cast(result); + } + else + { + // Complex type, reinterpret_cast is not safe for this case. + // Create a temporary of size(T) as unsigned type via copy instead. + UIntType temp; + std::memcpy(&temp, &value, sizeof(T)); + auto result = ByteSwap::SwapBE(temp); + T res; + std::memcpy(&res, &result, sizeof(T)); + return res; + } } From e91e68e3ec9c7f2771216fd8c9e63a3ad8020d6a Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Aug 2020 16:19:57 +0200 Subject: [PATCH 7/7] Re-phrase the network compatibility comment --- src/openrct2/network/NetworkConnection.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index 1d1a4f1df1..1f21a8c3b1 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -59,8 +59,8 @@ int32_t NetworkConnection::ReadPacket() header.Size = Convert::NetworkToHost(header.Size); header.Id = ByteSwapBE(header.Id); - // NOTE: For compatibility reasons we need to remove sizeof(Header.Id) from the size - // Previously the Id field was not part of the header rather part of the body + // NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size. + // Previously the Id field was not part of the header rather part of the body. header.Size -= sizeof(header.Id); // Fall-through: Read rest of packet. @@ -105,8 +105,8 @@ bool NetworkConnection::SendPacket(NetworkPacket& packet) std::vector buffer; buffer.reserve(sizeof(header) + header.Size); - // NOTE: For compatibility reasons we need to add sizeof(Header.Id) to the size - // Previously the Id field was not part of the header rather part of the body + // NOTE: For compatibility reasons for the master server we need to add sizeof(Header.Id) to the size. + // Previously the Id field was not part of the header rather part of the body. header.Size += sizeof(header.Id); header.Size = Convert::HostToNetwork(header.Size); header.Id = ByteSwapBE(header.Id);