diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index 2ce80a79a6..f792d203e2 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -1736,11 +1736,8 @@ bool NetworkBase::ProcessConnection(NetworkConnection& connection) void NetworkBase::ProcessPacket(NetworkConnection& connection, NetworkPacket& packet) { - std::underlying_type::type command; - packet >> command; - const auto& handlerList = GetMode() == NETWORK_MODE_SERVER ? server_command_handlers : client_command_handlers; - auto it = handlerList.find(static_cast(command)); + auto it = handlerList.find(packet.GetCommand()); if (it != handlerList.end()) { auto commandHandler = it->second; @@ -2650,7 +2647,7 @@ void NetworkBase::Client_Handle_MAP([[maybe_unused]] NetworkConnection& connecti { uint32_t size, offset; packet >> size >> offset; - int32_t chunksize = static_cast(packet.Size - packet.BytesRead); + int32_t chunksize = static_cast(packet.Header.Size - packet.BytesRead); if (chunksize <= 0) { return; @@ -2925,7 +2922,7 @@ void NetworkBase::Client_Handle_GAME_ACTION([[maybe_unused]] NetworkConnection& packet >> tick >> actionType; MemoryStream stream; - size_t size = packet.Size - packet.BytesRead; + const size_t size = packet.Header.Size - packet.BytesRead; stream.WriteArray(packet.Read(size), size); stream.SetPosition(0); @@ -3015,7 +3012,7 @@ void NetworkBase::Server_Handle_GAME_ACTION(NetworkConnection& connection, Netwo } DataSerialiser stream(false); - size_t size = packet.Size - packet.BytesRead; + const size_t size = packet.Header.Size - packet.BytesRead; stream.GetStream().WriteArray(packet.Read(size), size); stream.GetStream().SetPosition(0); diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index 74b09886cc..1d1a4f1df1 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -18,6 +18,7 @@ # include "network.h" constexpr size_t NETWORK_DISCONNECT_REASON_BUFFER_SIZE = 256; +constexpr size_t NetworkBufferSize = 1024; NetworkConnection::NetworkConnection() { @@ -31,47 +32,61 @@ NetworkConnection::~NetworkConnection() int32_t NetworkConnection::ReadPacket() { - if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Size)) + size_t bytesRead = 0; + + // Read packet header. + auto& header = InboundPacket.Header; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - // read packet size - void* buffer = &(reinterpret_cast(&InboundPacket.Size))[InboundPacket.BytesTransferred]; - size_t bufferLength = sizeof(InboundPacket.Size) - InboundPacket.BytesTransferred; - size_t readBytes; - NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes); + const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred; + + uint8_t* buffer = reinterpret_cast(&InboundPacket.Header); + + NETWORK_READPACKET status = Socket->ReceiveData(buffer, missingLength, &bytesRead); if (status != NETWORK_READPACKET_SUCCESS) { return status; } - InboundPacket.BytesTransferred += readBytes; - if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size)) + InboundPacket.BytesTransferred += bytesRead; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - InboundPacket.Size = Convert::NetworkToHost(InboundPacket.Size); - if (InboundPacket.Size == 0) // Can't have a size 0 packet - { - return NETWORK_READPACKET_DISCONNECTED; - } - InboundPacket.Data.resize(InboundPacket.Size); + // If still not enough data for header, keep waiting. + return NETWORK_READPACKET_MORE_DATA; } + + // Normalise values. + header.Size = Convert::NetworkToHost(header.Size); + header.Id = ByteSwapBE(header.Id); + + // NOTE: For compatibility reasons we need to remove sizeof(Header.Id) from the size + // Previously the Id field was not part of the header rather part of the body + header.Size -= sizeof(header.Id); + + // Fall-through: Read rest of packet. } - else + + // Read packet body. { - // read packet data - if (InboundPacket.Data.capacity() > 0) + const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header)); + + uint8_t buffer[NetworkBufferSize]; + + if (missingLength > 0) { - void* buffer = &InboundPacket.GetData()[InboundPacket.BytesTransferred - sizeof(InboundPacket.Size)]; - size_t bufferLength = sizeof(InboundPacket.Size) + InboundPacket.Size - InboundPacket.BytesTransferred; - size_t readBytes; - NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes); + NETWORK_READPACKET status = Socket->ReceiveData(buffer, std::min(missingLength, NetworkBufferSize), &bytesRead); if (status != NETWORK_READPACKET_SUCCESS) { return status; } - InboundPacket.BytesTransferred += readBytes; + InboundPacket.BytesTransferred += bytesRead; + InboundPacket.Write(buffer, bytesRead); } - if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size) + InboundPacket.Size) + + if (InboundPacket.Data.size() == header.Size) { + // Received complete packet. _lastPacketTime = platform_get_ticks(); RecordPacketStats(InboundPacket, false); @@ -79,26 +94,34 @@ int32_t NetworkConnection::ReadPacket() return NETWORK_READPACKET_SUCCESS; } } + return NETWORK_READPACKET_MORE_DATA; } bool NetworkConnection::SendPacket(NetworkPacket& packet) { - uint16_t sizen = Convert::HostToNetwork(packet.Size); - std::vector tosend; - tosend.reserve(sizeof(sizen) + packet.Size); - tosend.insert(tosend.end(), reinterpret_cast(&sizen), reinterpret_cast(&sizen) + sizeof(sizen)); - tosend.insert(tosend.end(), packet.Data.begin(), packet.Data.end()); + auto header = packet.Header; - const void* buffer = &tosend[packet.BytesTransferred]; - size_t bufferSize = tosend.size() - packet.BytesTransferred; - size_t sent = Socket->SendData(buffer, bufferSize); + std::vector buffer; + buffer.reserve(sizeof(header) + header.Size); + + // NOTE: For compatibility reasons we need to add sizeof(Header.Id) to the size + // Previously the Id field was not part of the header rather part of the body + header.Size += sizeof(header.Id); + header.Size = Convert::HostToNetwork(header.Size); + header.Id = ByteSwapBE(header.Id); + + buffer.insert(buffer.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); + buffer.insert(buffer.end(), packet.Data.begin(), packet.Data.end()); + + size_t bufferSize = buffer.size() - packet.BytesTransferred; + size_t sent = Socket->SendData(buffer.data() + packet.BytesTransferred, bufferSize); if (sent > 0) { packet.BytesTransferred += sent; } - bool sendComplete = packet.BytesTransferred == tosend.size(); + bool sendComplete = packet.BytesTransferred == buffer.size(); if (sendComplete) { RecordPacketStats(packet, true); @@ -110,7 +133,7 @@ void NetworkConnection::QueuePacket(NetworkPacket&& packet, bool front) { if (AuthStatus == NETWORK_AUTH_OK || !packet.CommandRequiresAuth()) { - packet.Size = static_cast(packet.Data.size()); + packet.Header.Size = static_cast(packet.Data.size()); if (front) { // If the first packet was already partially sent add new packet to second position diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index bf285329e0..c781fd1978 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -16,8 +16,8 @@ # include NetworkPacket::NetworkPacket(NetworkCommand id) + : Header{ 0, id } { - *this << static_cast::type>(id); } uint8_t* NetworkPacket::GetData() @@ -32,13 +32,7 @@ const uint8_t* NetworkPacket::GetData() const NetworkCommand NetworkPacket::GetCommand() const { - if (Data.size() < sizeof(uint32_t)) - return NetworkCommand::Invalid; - - uint32_t commandId = 0; - std::memcpy(&commandId, GetData(), sizeof(commandId)); - - return static_cast(ByteSwapBE(commandId)); + return Header.Id; } void NetworkPacket::Clear() @@ -78,7 +72,7 @@ void NetworkPacket::WriteString(const utf8* string) const uint8_t* NetworkPacket::Read(size_t size) { - if (BytesRead + size > NetworkPacket::Size) + if (BytesRead + size > Header.Size) { return nullptr; } @@ -94,7 +88,7 @@ const utf8* NetworkPacket::ReadString() { char* str = reinterpret_cast(&GetData()[BytesRead]); char* strend = str; - while (BytesRead < Size && *strend != 0) + while (BytesRead < Header.Size && *strend != 0) { BytesRead++; strend++; diff --git a/src/openrct2/network/NetworkPacket.h b/src/openrct2/network/NetworkPacket.h index bd376f993d..ffa464333a 100644 --- a/src/openrct2/network/NetworkPacket.h +++ b/src/openrct2/network/NetworkPacket.h @@ -16,13 +16,17 @@ #include #include -struct NetworkPacket final +#pragma pack(push, 1) +struct PacketHeader { uint16_t Size = 0; - std::vector Data; - size_t BytesTransferred = 0; - size_t BytesRead = 0; + NetworkCommand Id = NetworkCommand::Invalid; +}; +static_assert(sizeof(PacketHeader) == 6); +#pragma pack(pop) +struct NetworkPacket final +{ NetworkPacket() = default; NetworkPacket(NetworkCommand id); @@ -42,7 +46,7 @@ struct NetworkPacket final template NetworkPacket& operator>>(T& value) { - if (BytesRead + sizeof(value) > Size) + if (BytesRead + sizeof(value) > Header.Size) { value = T{}; } @@ -68,4 +72,10 @@ struct NetworkPacket final Write(static_cast(data.GetStream().GetData()), data.GetStream().GetLength()); return *this; } + +public: + PacketHeader Header{}; + std::vector Data; + size_t BytesTransferred = 0; + size_t BytesRead = 0; };