diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index e995de63f7..7f3578ff38 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -2488,6 +2488,12 @@ void NetworkBase::Server_Handle_MAPREQUEST(NetworkConnection& connection, Networ for (uint32_t i = 0; i < size; i++) { const char* name = reinterpret_cast(packet.Read(8)); + if (name == nullptr) + { + log_error("Client sent malformed object request data %s", connection.Socket->GetHostName()); + return; + } + // This is required, as packet does not have null terminator std::string s(name, name + 8); log_verbose("Client requested object %s", s.c_str()); diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index ffb0e46e1d..14039d7888 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -61,13 +61,14 @@ NetworkReadPacket NetworkConnection::ReadPacket() // NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size. // Previously the Id field was not part of the header rather part of the body. - header.Size -= sizeof(header.Id); + header.Size -= std::min(header.Size, sizeof(header.Id)); // Fall-through: Read rest of packet. } // Read packet body. { + // NOTE: BytesTransfered includes the header length, this will not underflow. const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header)); uint8_t buffer[NetworkBufferSize]; diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index ee71eebf8b..9efca470db 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -73,13 +73,13 @@ void NetworkPacket::WriteString(const utf8* string) const uint8_t* NetworkPacket::Read(size_t size) { - if (BytesRead + size > Header.Size) + if (BytesRead + size > Data.size()) { return nullptr; } else { - uint8_t* data = &GetData()[BytesRead]; + const uint8_t* data = Data.data() + BytesRead; BytesRead += size; return data; } @@ -87,18 +87,24 @@ const uint8_t* NetworkPacket::Read(size_t size) const utf8* NetworkPacket::ReadString() { - char* str = reinterpret_cast(&GetData()[BytesRead]); - char* strend = str; - while (BytesRead < Header.Size && *strend != 0) + if (BytesRead >= Data.size()) + return nullptr; + + const char* str = reinterpret_cast(Data.data() + BytesRead); + + size_t stringLen = 0; + while (BytesRead < Data.size() && str[stringLen] != '\0') { BytesRead++; - strend++; + stringLen++; } - if (*strend != 0) - { + + if (str[stringLen] != '\0') return nullptr; - } + + // Skip null terminator. BytesRead++; + return str; }