From 6ea5959b2bf10bfbd35ebb4efb7d0957ddf1e91e Mon Sep 17 00:00:00 2001 From: Aaron van Geffen Date: Sun, 31 Aug 2025 15:26:10 +0200 Subject: [PATCH] Move network units to OpenRCT2::Network namespace --- src/openrct2/network/DiscordService.cpp | 226 +- src/openrct2/network/DiscordService.h | 23 +- src/openrct2/network/Network.h | 163 +- src/openrct2/network/NetworkAction.cpp | 443 +- src/openrct2/network/NetworkAction.h | 83 +- src/openrct2/network/NetworkBase.cpp | 7961 +++++++++-------- src/openrct2/network/NetworkBase.h | 454 +- src/openrct2/network/NetworkClient.h | 9 +- src/openrct2/network/NetworkConnection.cpp | 363 +- src/openrct2/network/NetworkConnection.h | 69 +- src/openrct2/network/NetworkGroup.cpp | 183 +- src/openrct2/network/NetworkGroup.h | 63 +- src/openrct2/network/NetworkKey.cpp | 331 +- src/openrct2/network/NetworkKey.h | 39 +- src/openrct2/network/NetworkPacket.cpp | 173 +- src/openrct2/network/NetworkPacket.h | 113 +- src/openrct2/network/NetworkPlayer.cpp | 61 +- src/openrct2/network/NetworkPlayer.h | 56 +- src/openrct2/network/NetworkServer.h | 9 +- .../network/NetworkServerAdvertiser.cpp | 559 +- .../network/NetworkServerAdvertiser.h | 31 +- src/openrct2/network/NetworkTypes.h | 229 +- src/openrct2/network/NetworkUser.cpp | 321 +- src/openrct2/network/NetworkUser.h | 83 +- src/openrct2/network/ServerList.cpp | 727 +- src/openrct2/network/ServerList.h | 125 +- src/openrct2/network/Socket.cpp | 1665 ++-- src/openrct2/network/Socket.h | 145 +- 28 files changed, 7396 insertions(+), 7311 deletions(-) diff --git a/src/openrct2/network/DiscordService.cpp b/src/openrct2/network/DiscordService.cpp index c4b3f35d31..218a5b94d2 100644 --- a/src/openrct2/network/DiscordService.cpp +++ b/src/openrct2/network/DiscordService.cpp @@ -24,130 +24,128 @@ #include #include -using namespace OpenRCT2; - -namespace +namespace OpenRCT2::Network { using namespace std::chrono_literals; constexpr const char* kApplicationID = "378612438200877056"; constexpr const char* kSteamAppID = nullptr; constexpr auto kRefreshInterval = 5.0s; -} // namespace -static void OnReady([[maybe_unused]] const DiscordUser* request) -{ - LOG_VERBOSE("DiscordService::OnReady()"); -} - -static void OnDisconnected(int errorCode, const char* message) -{ - Console::Error::WriteLine("DiscordService::OnDisconnected(%d, %s)", errorCode, message); -} - -static void OnErrored(int errorCode, const char* message) -{ - Console::Error::WriteLine("DiscordService::OnErrored(%d, %s)", errorCode, message); -} - -DiscordService::DiscordService() -{ - DiscordEventHandlers handlers = {}; - handlers.ready = OnReady; - handlers.disconnected = OnDisconnected; - handlers.errored = OnErrored; - Discord_Initialize(kApplicationID, &handlers, 1, kSteamAppID); -} - -DiscordService::~DiscordService() -{ - Discord_Shutdown(); -} - -static std::string GetParkName() -{ - auto& gameState = getGameState(); - return gameState.park.name; -} - -void DiscordService::Tick() -{ - Discord_RunCallbacks(); - - if (_updateTimer.GetElapsedTime() < kRefreshInterval) - return; - - RefreshPresence(); - _updateTimer.Restart(); -} - -void DiscordService::RefreshPresence() const -{ - DiscordRichPresence discordPresence = {}; - discordPresence.largeImageKey = "logo"; - - std::string state; - std::string details; - std::string partyId; - - switch (gLegacyScene) + static void OnReady([[maybe_unused]] const DiscordUser* request) { - default: - details = GetParkName(); - if (NetworkGetMode() == NETWORK_MODE_NONE) - { - state = "Playing Solo"; - } - else - { - OpenRCT2::FmtString fmtServerName(NetworkGetServerName()); - std::string serverName; - for (const auto& token : fmtServerName) - { - if (token.IsLiteral()) - { - serverName += token.text; - } - else if (token.IsCodepoint()) - { - auto codepoint = token.GetCodepoint(); - char buffer[8]{}; - UTF8WriteCodepoint(buffer, codepoint); - serverName += buffer; - } - } - state = serverName; - - partyId = NetworkGetServerName(); - // NOTE: the party size is displayed next to state - discordPresence.partyId = partyId.c_str(); - discordPresence.partySize = NetworkGetNumPlayers(); - discordPresence.partyMax = 256; - - // TODO generate secrets for the server - discordPresence.matchSecret = nullptr; - discordPresence.spectateSecret = nullptr; - discordPresence.instance = 1; - } - break; - case LegacyScene::titleSequence: - details = "In Menus"; - break; - case LegacyScene::scenarioEditor: - details = "In Scenario Editor"; - break; - case LegacyScene::trackDesigner: - details = "In Track Designer"; - break; - case LegacyScene::trackDesignsManager: - details = "In Track Designs Manager"; - break; + LOG_VERBOSE("DiscordService::OnReady()"); } - discordPresence.state = state.c_str(); - discordPresence.details = details.c_str(); + static void OnDisconnected(int errorCode, const char* message) + { + Console::Error::WriteLine("DiscordService::OnDisconnected(%d, %s)", errorCode, message); + } - Discord_UpdatePresence(&discordPresence); -} + static void OnErrored(int errorCode, const char* message) + { + Console::Error::WriteLine("DiscordService::OnErrored(%d, %s)", errorCode, message); + } + + DiscordService::DiscordService() + { + DiscordEventHandlers handlers = {}; + handlers.ready = OnReady; + handlers.disconnected = OnDisconnected; + handlers.errored = OnErrored; + Discord_Initialize(kApplicationID, &handlers, 1, kSteamAppID); + } + + DiscordService::~DiscordService() + { + Discord_Shutdown(); + } + + static std::string GetParkName() + { + auto& gameState = getGameState(); + return gameState.park.name; + } + + void DiscordService::Tick() + { + Discord_RunCallbacks(); + + if (_updateTimer.GetElapsedTime() < kRefreshInterval) + return; + + RefreshPresence(); + _updateTimer.Restart(); + } + + void DiscordService::RefreshPresence() const + { + DiscordRichPresence discordPresence = {}; + discordPresence.largeImageKey = "logo"; + + std::string state; + std::string details; + std::string partyId; + + switch (gLegacyScene) + { + default: + details = GetParkName(); + if (NetworkGetMode() == NETWORK_MODE_NONE) + { + state = "Playing Solo"; + } + else + { + OpenRCT2::FmtString fmtServerName(NetworkGetServerName()); + std::string serverName; + for (const auto& token : fmtServerName) + { + if (token.IsLiteral()) + { + serverName += token.text; + } + else if (token.IsCodepoint()) + { + auto codepoint = token.GetCodepoint(); + char buffer[8]{}; + UTF8WriteCodepoint(buffer, codepoint); + serverName += buffer; + } + } + state = serverName; + + partyId = NetworkGetServerName(); + // NOTE: the party size is displayed next to state + discordPresence.partyId = partyId.c_str(); + discordPresence.partySize = NetworkGetNumPlayers(); + discordPresence.partyMax = 256; + + // TODO generate secrets for the server + discordPresence.matchSecret = nullptr; + discordPresence.spectateSecret = nullptr; + discordPresence.instance = 1; + } + break; + case LegacyScene::titleSequence: + details = "In Menus"; + break; + case LegacyScene::scenarioEditor: + details = "In Scenario Editor"; + break; + case LegacyScene::trackDesigner: + details = "In Track Designer"; + break; + case LegacyScene::trackDesignsManager: + details = "In Track Designs Manager"; + break; + } + + discordPresence.state = state.c_str(); + discordPresence.details = details.c_str(); + + Discord_UpdatePresence(&discordPresence); + } +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/DiscordService.h b/src/openrct2/network/DiscordService.h index 2945814dd5..75e15b3cdb 100644 --- a/src/openrct2/network/DiscordService.h +++ b/src/openrct2/network/DiscordService.h @@ -15,19 +15,22 @@ #include -class DiscordService final +namespace OpenRCT2::Network { -private: - OpenRCT2::Timer _updateTimer; + class DiscordService final + { + private: + OpenRCT2::Timer _updateTimer; -public: - DiscordService(); - ~DiscordService(); + public: + DiscordService(); + ~DiscordService(); - void Tick(); + void Tick(); -private: - void RefreshPresence() const; -}; + private: + void RefreshPresence() const; + }; +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/Network.h b/src/openrct2/network/Network.h index 8c1c17e7b0..4c908f65c3 100644 --- a/src/openrct2/network/Network.h +++ b/src/openrct2/network/Network.h @@ -19,12 +19,6 @@ #include #include -constexpr uint16_t kNetworkDefaultPort = 11753; -constexpr uint16_t kNetworkLanBroadcastPort = 11754; -constexpr const char* kNetworkLanBroadcastMsg = "openrct2.server.query"; -constexpr const char* kMasterServerURL = "https://servers.openrct2.io"; -constexpr uint16_t kMaxServerDescriptionLength = 256; - struct Peep; struct CoordsXYZ; @@ -36,85 +30,94 @@ namespace OpenRCT2::GameActions class Result; } // namespace OpenRCT2::GameActions -enum class NetworkPermission : uint32_t; +namespace OpenRCT2::Network +{ + constexpr uint16_t kNetworkDefaultPort = 11753; + constexpr uint16_t kNetworkLanBroadcastPort = 11754; + constexpr const char* kNetworkLanBroadcastMsg = "openrct2.server.query"; + constexpr const char* kMasterServerURL = "https://servers.openrct2.io"; + constexpr uint16_t kMaxServerDescriptionLength = 256; -void NetworkReconnect(); -void NetworkShutdownClient(); -int32_t NetworkBeginClient(const std::string& host, int32_t port); -int32_t NetworkBeginServer(int32_t port, const std::string& address); + enum class NetworkPermission : uint32_t; -[[nodiscard]] int32_t NetworkGetMode(); -[[nodiscard]] int32_t NetworkGetStatus(); -bool NetworkIsDesynchronised(); -bool NetworkCheckDesynchronisation(); -void NetworkRequestGamestateSnapshot(); -void NetworkSendTick(); -bool NetworkGamestateSnapshotsEnabled(); -void NetworkUpdate(); -void NetworkProcessPending(); -void NetworkFlush(); + void NetworkReconnect(); + void NetworkShutdownClient(); + int32_t NetworkBeginClient(const std::string& host, int32_t port); + int32_t NetworkBeginServer(int32_t port, const std::string& address); -[[nodiscard]] NetworkAuth NetworkGetAuthstatus(); -[[nodiscard]] uint32_t NetworkGetServerTick(); -[[nodiscard]] uint8_t NetworkGetCurrentPlayerId(); -[[nodiscard]] int32_t NetworkGetNumPlayers(); -[[nodiscard]] int32_t NetworkGetNumVisiblePlayers(); -[[nodiscard]] const char* NetworkGetPlayerName(uint32_t index); -[[nodiscard]] uint32_t NetworkGetPlayerFlags(uint32_t index); -[[nodiscard]] int32_t NetworkGetPlayerPing(uint32_t index); -[[nodiscard]] int32_t NetworkGetPlayerID(uint32_t index); -[[nodiscard]] money64 NetworkGetPlayerMoneySpent(uint32_t index); -[[nodiscard]] std::string NetworkGetPlayerIPAddress(uint32_t id); -[[nodiscard]] std::string NetworkGetPlayerPublicKeyHash(uint32_t id); -void NetworkIncrementPlayerNumCommands(uint32_t playerIndex); -void NetworkAddPlayerMoneySpent(uint32_t index, money64 cost); -[[nodiscard]] int32_t NetworkGetPlayerLastAction(uint32_t index, int32_t time); -void NetworkSetPlayerLastAction(uint32_t index, GameCommand command); -[[nodiscard]] CoordsXYZ NetworkGetPlayerLastActionCoord(uint32_t index); -void NetworkSetPlayerLastActionCoord(uint32_t index, const CoordsXYZ& coord); -[[nodiscard]] uint32_t NetworkGetPlayerCommandsRan(uint32_t index); -[[nodiscard]] int32_t NetworkGetPlayerIndex(uint32_t id); -[[nodiscard]] uint8_t NetworkGetPlayerGroup(uint32_t index); -void NetworkSetPlayerGroup(uint32_t index, uint32_t groupindex); -[[nodiscard]] int32_t NetworkGetGroupIndex(uint8_t id); -[[nodiscard]] int32_t NetworkGetCurrentPlayerGroupIndex(); -[[nodiscard]] uint8_t NetworkGetGroupID(uint32_t index); -[[nodiscard]] int32_t NetworkGetNumGroups(); -[[nodiscard]] const char* NetworkGetGroupName(uint32_t index); -[[nodiscard]] OpenRCT2::GameActions::Result NetworkSetPlayerGroup( - NetworkPlayerId_t actionPlayerId, NetworkPlayerId_t playerId, uint8_t groupId, bool isExecuting); -[[nodiscard]] OpenRCT2::GameActions::Result NetworkModifyGroups( - NetworkPlayerId_t actionPlayerId, OpenRCT2::GameActions::ModifyGroupType type, uint8_t groupId, const std::string& name, - uint32_t permissionIndex, OpenRCT2::GameActions::PermissionState permissionState, bool isExecuting); -[[nodiscard]] OpenRCT2::GameActions::Result NetworkKickPlayer(NetworkPlayerId_t playerId, bool isExecuting); -[[nodiscard]] uint8_t NetworkGetDefaultGroup(); -[[nodiscard]] int32_t NetworkGetNumActions(); -[[nodiscard]] StringId NetworkGetActionNameStringID(uint32_t index); -[[nodiscard]] int32_t NetworkCanPerformAction(uint32_t groupindex, NetworkPermission index); -[[nodiscard]] int32_t NetworkCanPerformCommand(uint32_t groupindex, int32_t index); -void NetworkSetPickupPeep(uint8_t playerid, Peep* peep); -[[nodiscard]] Peep* NetworkGetPickupPeep(uint8_t playerid); -void NetworkSetPickupPeepOldX(uint8_t playerid, int32_t x); -[[nodiscard]] int32_t NetworkGetPickupPeepOldX(uint8_t playerid); -[[nodiscard]] bool NetworkIsServerPlayerInvisible(); + [[nodiscard]] int32_t NetworkGetMode(); + [[nodiscard]] int32_t NetworkGetStatus(); + bool NetworkIsDesynchronised(); + bool NetworkCheckDesynchronisation(); + void NetworkRequestGamestateSnapshot(); + void NetworkSendTick(); + bool NetworkGamestateSnapshotsEnabled(); + void NetworkUpdate(); + void NetworkProcessPending(); + void NetworkFlush(); -void NetworkSendChat(const char* text, const std::vector& playerIds = {}); -void NetworkSendGameAction(const OpenRCT2::GameActions::GameAction* action); -void NetworkSendPassword(const std::string& password); + [[nodiscard]] NetworkAuth NetworkGetAuthstatus(); + [[nodiscard]] uint32_t NetworkGetServerTick(); + [[nodiscard]] uint8_t NetworkGetCurrentPlayerId(); + [[nodiscard]] int32_t NetworkGetNumPlayers(); + [[nodiscard]] int32_t NetworkGetNumVisiblePlayers(); + [[nodiscard]] const char* NetworkGetPlayerName(uint32_t index); + [[nodiscard]] uint32_t NetworkGetPlayerFlags(uint32_t index); + [[nodiscard]] int32_t NetworkGetPlayerPing(uint32_t index); + [[nodiscard]] int32_t NetworkGetPlayerID(uint32_t index); + [[nodiscard]] money64 NetworkGetPlayerMoneySpent(uint32_t index); + [[nodiscard]] std::string NetworkGetPlayerIPAddress(uint32_t id); + [[nodiscard]] std::string NetworkGetPlayerPublicKeyHash(uint32_t id); + void NetworkIncrementPlayerNumCommands(uint32_t playerIndex); + void NetworkAddPlayerMoneySpent(uint32_t index, money64 cost); + [[nodiscard]] int32_t NetworkGetPlayerLastAction(uint32_t index, int32_t time); + void NetworkSetPlayerLastAction(uint32_t index, GameCommand command); + [[nodiscard]] CoordsXYZ NetworkGetPlayerLastActionCoord(uint32_t index); + void NetworkSetPlayerLastActionCoord(uint32_t index, const CoordsXYZ& coord); + [[nodiscard]] uint32_t NetworkGetPlayerCommandsRan(uint32_t index); + [[nodiscard]] int32_t NetworkGetPlayerIndex(uint32_t id); + [[nodiscard]] uint8_t NetworkGetPlayerGroup(uint32_t index); + void NetworkSetPlayerGroup(uint32_t index, uint32_t groupindex); + [[nodiscard]] int32_t NetworkGetGroupIndex(uint8_t id); + [[nodiscard]] int32_t NetworkGetCurrentPlayerGroupIndex(); + [[nodiscard]] uint8_t NetworkGetGroupID(uint32_t index); + [[nodiscard]] int32_t NetworkGetNumGroups(); + [[nodiscard]] const char* NetworkGetGroupName(uint32_t index); + [[nodiscard]] OpenRCT2::GameActions::Result NetworkSetPlayerGroup( + NetworkPlayerId_t actionPlayerId, NetworkPlayerId_t playerId, uint8_t groupId, bool isExecuting); + [[nodiscard]] OpenRCT2::GameActions::Result NetworkModifyGroups( + NetworkPlayerId_t actionPlayerId, OpenRCT2::GameActions::ModifyGroupType type, uint8_t groupId, const std::string& name, + uint32_t permissionIndex, OpenRCT2::GameActions::PermissionState permissionState, bool isExecuting); + [[nodiscard]] OpenRCT2::GameActions::Result NetworkKickPlayer(NetworkPlayerId_t playerId, bool isExecuting); + [[nodiscard]] uint8_t NetworkGetDefaultGroup(); + [[nodiscard]] int32_t NetworkGetNumActions(); + [[nodiscard]] StringId NetworkGetActionNameStringID(uint32_t index); + [[nodiscard]] int32_t NetworkCanPerformAction(uint32_t groupindex, NetworkPermission index); + [[nodiscard]] int32_t NetworkCanPerformCommand(uint32_t groupindex, int32_t index); + void NetworkSetPickupPeep(uint8_t playerid, Peep* peep); + [[nodiscard]] Peep* NetworkGetPickupPeep(uint8_t playerid); + void NetworkSetPickupPeepOldX(uint8_t playerid, int32_t x); + [[nodiscard]] int32_t NetworkGetPickupPeepOldX(uint8_t playerid); + [[nodiscard]] bool NetworkIsServerPlayerInvisible(); -void NetworkSetPassword(const char* password); + void NetworkSendChat(const char* text, const std::vector& playerIds = {}); + void NetworkSendGameAction(const OpenRCT2::GameActions::GameAction* action); + void NetworkSendPassword(const std::string& password); -void NetworkAppendChatLog(std::string_view text); -void NetworkAppendServerLog(const utf8* text); -[[nodiscard]] u8string NetworkGetServerName(); -[[nodiscard]] u8string NetworkGetServerDescription(); -[[nodiscard]] u8string NetworkGetServerGreeting(); -[[nodiscard]] u8string NetworkGetServerProviderName(); -[[nodiscard]] u8string NetworkGetServerProviderEmail(); -[[nodiscard]] u8string NetworkGetServerProviderWebsite(); + void NetworkSetPassword(const char* password); -[[nodiscard]] std::string NetworkGetVersion(); + void NetworkAppendChatLog(std::string_view text); + void NetworkAppendServerLog(const utf8* text); + [[nodiscard]] u8string NetworkGetServerName(); + [[nodiscard]] u8string NetworkGetServerDescription(); + [[nodiscard]] u8string NetworkGetServerGreeting(); + [[nodiscard]] u8string NetworkGetServerProviderName(); + [[nodiscard]] u8string NetworkGetServerProviderEmail(); + [[nodiscard]] u8string NetworkGetServerProviderWebsite(); -[[nodiscard]] NetworkStats NetworkGetStats(); -[[nodiscard]] NetworkServerState NetworkGetServerState(); -[[nodiscard]] json_t NetworkGetServerInfoAsJson(); + [[nodiscard]] std::string NetworkGetVersion(); + + [[nodiscard]] NetworkStats NetworkGetStats(); + [[nodiscard]] NetworkServerState NetworkGetServerState(); + [[nodiscard]] json_t NetworkGetServerInfoAsJson(); +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkAction.cpp b/src/openrct2/network/NetworkAction.cpp index ce557c0d21..978d4a6051 100644 --- a/src/openrct2/network/NetworkAction.cpp +++ b/src/openrct2/network/NetworkAction.cpp @@ -16,254 +16,257 @@ #include -NetworkPermission NetworkActions::FindCommand(GameCommand command) +namespace OpenRCT2::Network { - auto it = std::find_if(Actions.begin(), Actions.end(), [&command](NetworkAction const& action) { - for (GameCommand currentCommand : action.Commands) - { - if (currentCommand == command) + NetworkPermission NetworkActions::FindCommand(GameCommand command) + { + auto it = std::find_if(Actions.begin(), Actions.end(), [&command](NetworkAction const& action) { + for (GameCommand currentCommand : action.Commands) { - return true; + if (currentCommand == command) + { + return true; + } } + return false; + }); + if (it != Actions.end()) + { + return static_cast(it - Actions.begin()); } - return false; - }); - if (it != Actions.end()) - { - return static_cast(it - Actions.begin()); + return NetworkPermission::Count; } - return NetworkPermission::Count; -} -NetworkPermission NetworkActions::FindCommandByPermissionName(const std::string& permission_name) -{ - auto it = std::find_if(Actions.begin(), Actions.end(), [&permission_name](NetworkAction const& action) { - return action.PermissionName == permission_name; - }); - if (it != Actions.end()) + NetworkPermission NetworkActions::FindCommandByPermissionName(const std::string& permission_name) { - return static_cast(it - Actions.begin()); + auto it = std::find_if(Actions.begin(), Actions.end(), [&permission_name](NetworkAction const& action) { + return action.PermissionName == permission_name; + }); + if (it != Actions.end()) + { + return static_cast(it - Actions.begin()); + } + return NetworkPermission::Count; } - return NetworkPermission::Count; -} -const std::array(NetworkPermission::Count)> NetworkActions::Actions = { - NetworkAction{ - STR_ACTION_CHAT, - "PERMISSION_CHAT", - {}, - }, - NetworkAction{ - STR_ACTION_TERRAFORM, - "PERMISSION_TERRAFORM", - { - GameCommand::SetLandHeight, - GameCommand::RaiseLand, - GameCommand::LowerLand, - GameCommand::EditLandSmooth, - GameCommand::ChangeSurfaceStyle, + const std::array(NetworkPermission::Count)> NetworkActions::Actions = { + NetworkAction{ + STR_ACTION_CHAT, + "PERMISSION_CHAT", + {}, }, - }, - NetworkAction{ - STR_ACTION_SET_WATER_LEVEL, - "PERMISSION_SET_WATER_LEVEL", - { - GameCommand::SetWaterHeight, - GameCommand::RaiseWater, - GameCommand::LowerWater, + NetworkAction{ + STR_ACTION_TERRAFORM, + "PERMISSION_TERRAFORM", + { + GameCommand::SetLandHeight, + GameCommand::RaiseLand, + GameCommand::LowerLand, + GameCommand::EditLandSmooth, + GameCommand::ChangeSurfaceStyle, + }, }, - }, - NetworkAction{ - STR_ACTION_TOGGLE_PAUSE, - "PERMISSION_TOGGLE_PAUSE", - { - GameCommand::TogglePause, + NetworkAction{ + STR_ACTION_SET_WATER_LEVEL, + "PERMISSION_SET_WATER_LEVEL", + { + GameCommand::SetWaterHeight, + GameCommand::RaiseWater, + GameCommand::LowerWater, + }, }, - }, - NetworkAction{ - STR_ACTION_CREATE_RIDE, - "PERMISSION_CREATE_RIDE", - { - GameCommand::CreateRide, + NetworkAction{ + STR_ACTION_TOGGLE_PAUSE, + "PERMISSION_TOGGLE_PAUSE", + { + GameCommand::TogglePause, + }, }, - }, - NetworkAction{ - STR_ACTION_REMOVE_RIDE, - "PERMISSION_REMOVE_RIDE", - { - GameCommand::DemolishRide, + NetworkAction{ + STR_ACTION_CREATE_RIDE, + "PERMISSION_CREATE_RIDE", + { + GameCommand::CreateRide, + }, }, - }, - NetworkAction{ - STR_ACTION_BUILD_RIDE, - "PERMISSION_BUILD_RIDE", - { - GameCommand::PlaceTrack, - GameCommand::RemoveTrack, - GameCommand::SetMazeTrack, - GameCommand::PlaceTrackDesign, - GameCommand::PlaceMazeDesign, - GameCommand::PlaceRideEntranceOrExit, - GameCommand::RemoveRideEntranceOrExit, + NetworkAction{ + STR_ACTION_REMOVE_RIDE, + "PERMISSION_REMOVE_RIDE", + { + GameCommand::DemolishRide, + }, }, - }, - NetworkAction{ - STR_ACTION_RIDE_PROPERTIES, - "PERMISSION_RIDE_PROPERTIES", - { - GameCommand::SetRideName, - GameCommand::SetRideAppearance, - GameCommand::SetRideStatus, - GameCommand::SetRideVehicles, - GameCommand::SetRideSetting, - GameCommand::SetRidePrice, - GameCommand::SetBrakesSpeed, - GameCommand::SetColourScheme, + NetworkAction{ + STR_ACTION_BUILD_RIDE, + "PERMISSION_BUILD_RIDE", + { + GameCommand::PlaceTrack, + GameCommand::RemoveTrack, + GameCommand::SetMazeTrack, + GameCommand::PlaceTrackDesign, + GameCommand::PlaceMazeDesign, + GameCommand::PlaceRideEntranceOrExit, + GameCommand::RemoveRideEntranceOrExit, + }, }, - }, - NetworkAction{ - STR_ACTION_SCENERY, - "PERMISSION_SCENERY", - { - GameCommand::RemoveScenery, - GameCommand::PlaceScenery, - GameCommand::SetBrakesSpeed, - GameCommand::RemoveWall, - GameCommand::PlaceWall, - GameCommand::RemoveLargeScenery, - GameCommand::PlaceLargeScenery, - GameCommand::PlaceBanner, - GameCommand::RemoveBanner, - GameCommand::SetSceneryColour, - GameCommand::SetWallColour, - GameCommand::SetLargeSceneryColour, - GameCommand::SetBannerColour, - GameCommand::SetBannerName, - GameCommand::SetSignName, - GameCommand::SetBannerStyle, - GameCommand::SetSignStyle, + NetworkAction{ + STR_ACTION_RIDE_PROPERTIES, + "PERMISSION_RIDE_PROPERTIES", + { + GameCommand::SetRideName, + GameCommand::SetRideAppearance, + GameCommand::SetRideStatus, + GameCommand::SetRideVehicles, + GameCommand::SetRideSetting, + GameCommand::SetRidePrice, + GameCommand::SetBrakesSpeed, + GameCommand::SetColourScheme, + }, }, - }, - NetworkAction{ - STR_ACTION_PATH, - "PERMISSION_PATH", - { - GameCommand::PlacePath, - GameCommand::PlacePathLayout, - GameCommand::RemovePath, - GameCommand::PlaceFootpathAddition, - GameCommand::RemoveFootpathAddition, + NetworkAction{ + STR_ACTION_SCENERY, + "PERMISSION_SCENERY", + { + GameCommand::RemoveScenery, + GameCommand::PlaceScenery, + GameCommand::SetBrakesSpeed, + GameCommand::RemoveWall, + GameCommand::PlaceWall, + GameCommand::RemoveLargeScenery, + GameCommand::PlaceLargeScenery, + GameCommand::PlaceBanner, + GameCommand::RemoveBanner, + GameCommand::SetSceneryColour, + GameCommand::SetWallColour, + GameCommand::SetLargeSceneryColour, + GameCommand::SetBannerColour, + GameCommand::SetBannerName, + GameCommand::SetSignName, + GameCommand::SetBannerStyle, + GameCommand::SetSignStyle, + }, }, - }, - NetworkAction{ - STR_ACTION_CLEAR_LANDSCAPE, - "PERMISSION_CLEAR_LANDSCAPE", - { - GameCommand::ClearScenery, + NetworkAction{ + STR_ACTION_PATH, + "PERMISSION_PATH", + { + GameCommand::PlacePath, + GameCommand::PlacePathLayout, + GameCommand::RemovePath, + GameCommand::PlaceFootpathAddition, + GameCommand::RemoveFootpathAddition, + }, }, - }, - NetworkAction{ - STR_ACTION_GUEST, - "PERMISSION_GUEST", - { - GameCommand::SetGuestName, - GameCommand::PickupGuest, - GameCommand::BalloonPress, - GameCommand::GuestSetFlags, + NetworkAction{ + STR_ACTION_CLEAR_LANDSCAPE, + "PERMISSION_CLEAR_LANDSCAPE", + { + GameCommand::ClearScenery, + }, }, - }, - NetworkAction{ - STR_ACTION_STAFF, - "PERMISSION_STAFF", - { - GameCommand::HireNewStaffMember, - GameCommand::SetStaffPatrol, - GameCommand::FireStaffMember, - GameCommand::SetStaffOrders, - GameCommand::SetStaffCostume, - GameCommand::SetStaffColour, - GameCommand::SetStaffName, - GameCommand::PickupStaff, + NetworkAction{ + STR_ACTION_GUEST, + "PERMISSION_GUEST", + { + GameCommand::SetGuestName, + GameCommand::PickupGuest, + GameCommand::BalloonPress, + GameCommand::GuestSetFlags, + }, }, - }, - NetworkAction{ - STR_ACTION_PARK_PROPERTIES, - "PERMISSION_PARK_PROPERTIES", - { - GameCommand::SetParkName, - GameCommand::SetParkOpen, - GameCommand::SetParkEntranceFee, - GameCommand::SetLandOwnership, - GameCommand::BuyLandRights, - GameCommand::PlaceParkEntrance, - GameCommand::RemoveParkEntrance, - GameCommand::PlacePeepSpawn, - GameCommand::ChangeMapSize, + NetworkAction{ + STR_ACTION_STAFF, + "PERMISSION_STAFF", + { + GameCommand::HireNewStaffMember, + GameCommand::SetStaffPatrol, + GameCommand::FireStaffMember, + GameCommand::SetStaffOrders, + GameCommand::SetStaffCostume, + GameCommand::SetStaffColour, + GameCommand::SetStaffName, + GameCommand::PickupStaff, + }, }, - }, - NetworkAction{ - STR_ACTION_PARK_FUNDING, - "PERMISSION_PARK_FUNDING", - { - GameCommand::SetCurrentLoan, - GameCommand::SetResearchFunding, - GameCommand::StartMarketingCampaign, + NetworkAction{ + STR_ACTION_PARK_PROPERTIES, + "PERMISSION_PARK_PROPERTIES", + { + GameCommand::SetParkName, + GameCommand::SetParkOpen, + GameCommand::SetParkEntranceFee, + GameCommand::SetLandOwnership, + GameCommand::BuyLandRights, + GameCommand::PlaceParkEntrance, + GameCommand::RemoveParkEntrance, + GameCommand::PlacePeepSpawn, + GameCommand::ChangeMapSize, + }, }, - }, - NetworkAction{ - STR_ACTION_KICK_PLAYER, - "PERMISSION_KICK_PLAYER", - { - GameCommand::KickPlayer, + NetworkAction{ + STR_ACTION_PARK_FUNDING, + "PERMISSION_PARK_FUNDING", + { + GameCommand::SetCurrentLoan, + GameCommand::SetResearchFunding, + GameCommand::StartMarketingCampaign, + }, }, - }, - NetworkAction{ - STR_ACTION_MODIFY_GROUPS, - "PERMISSION_MODIFY_GROUPS", - { - GameCommand::ModifyGroups, + NetworkAction{ + STR_ACTION_KICK_PLAYER, + "PERMISSION_KICK_PLAYER", + { + GameCommand::KickPlayer, + }, }, - }, - NetworkAction{ - STR_ACTION_SET_PLAYER_GROUP, - "PERMISSION_SET_PLAYER_GROUP", - { - GameCommand::SetPlayerGroup, + NetworkAction{ + STR_ACTION_MODIFY_GROUPS, + "PERMISSION_MODIFY_GROUPS", + { + GameCommand::ModifyGroups, + }, }, - }, - NetworkAction{ - STR_ACTION_CHEAT, - "PERMISSION_CHEAT", - { - GameCommand::Cheat, - GameCommand::SetDate, - GameCommand::FreezeRideRating, + NetworkAction{ + STR_ACTION_SET_PLAYER_GROUP, + "PERMISSION_SET_PLAYER_GROUP", + { + GameCommand::SetPlayerGroup, + }, }, - }, - NetworkAction{ - STR_ACTION_TOGGLE_SCENERY_CLUSTER, - "PERMISSION_TOGGLE_SCENERY_CLUSTER", - {}, - }, - NetworkAction{ - STR_ACTION_PASSWORDLESS_LOGIN, - "PERMISSION_PASSWORDLESS_LOGIN", - {}, - }, - NetworkAction{ - STR_ACTION_MODIFY_TILE, - "PERMISSION_MODIFY_TILE", - { - GameCommand::ModifyTile, + NetworkAction{ + STR_ACTION_CHEAT, + "PERMISSION_CHEAT", + { + GameCommand::Cheat, + GameCommand::SetDate, + GameCommand::FreezeRideRating, + }, }, - }, - NetworkAction{ - STR_ACTION_EDIT_SCENARIO_OPTIONS, - "PERMISSION_EDIT_SCENARIO_OPTIONS", - { - GameCommand::EditScenarioOptions, + NetworkAction{ + STR_ACTION_TOGGLE_SCENERY_CLUSTER, + "PERMISSION_TOGGLE_SCENERY_CLUSTER", + {}, }, - }, -}; + NetworkAction{ + STR_ACTION_PASSWORDLESS_LOGIN, + "PERMISSION_PASSWORDLESS_LOGIN", + {}, + }, + NetworkAction{ + STR_ACTION_MODIFY_TILE, + "PERMISSION_MODIFY_TILE", + { + GameCommand::ModifyTile, + }, + }, + NetworkAction{ + STR_ACTION_EDIT_SCENARIO_OPTIONS, + "PERMISSION_EDIT_SCENARIO_OPTIONS", + { + GameCommand::EditScenarioOptions, + }, + }, + }; +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/NetworkAction.h b/src/openrct2/network/NetworkAction.h index aa54671486..6cdde89dbf 100644 --- a/src/openrct2/network/NetworkAction.h +++ b/src/openrct2/network/NetworkAction.h @@ -16,48 +16,51 @@ #include #include -enum class NetworkPermission : uint32_t +namespace OpenRCT2::Network { - Chat, - Terraform, - SetWaterLevel, - TogglePause, - CreateRide, - RemoveRide, - BuildRide, - RideProperties, - Scenery, - Path, - ClearLandscape, - Guest, - Staff, - ParkProperties, - ParkFunding, - KickPlayer, - ModifyGroups, - SetPlayerGroup, - Cheat, - ToggleSceneryCluster, - PasswordlessLogin, - ModifyTile, - EditScenarioOptions, + enum class NetworkPermission : uint32_t + { + Chat, + Terraform, + SetWaterLevel, + TogglePause, + CreateRide, + RemoveRide, + BuildRide, + RideProperties, + Scenery, + Path, + ClearLandscape, + Guest, + Staff, + ParkProperties, + ParkFunding, + KickPlayer, + ModifyGroups, + SetPlayerGroup, + Cheat, + ToggleSceneryCluster, + PasswordlessLogin, + ModifyTile, + EditScenarioOptions, - Count -}; + Count + }; -class NetworkAction final -{ -public: - StringId Name; - std::string PermissionName; - std::vector Commands; -}; + class NetworkAction final + { + public: + StringId Name; + std::string PermissionName; + std::vector Commands; + }; -class NetworkActions final -{ -public: - static const std::array(NetworkPermission::Count)> Actions; + class NetworkActions final + { + public: + static const std::array(NetworkPermission::Count)> Actions; - static NetworkPermission FindCommand(GameCommand command); - static NetworkPermission FindCommandByPermissionName(const std::string& permission_name); -}; + static NetworkPermission FindCommand(GameCommand command); + static NetworkPermission FindCommandByPermissionName(const std::string& permission_name); + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkBase.cpp b/src/openrct2/network/NetworkBase.cpp index 067bd8095d..7cbe2b1c77 100644 --- a/src/openrct2/network/NetworkBase.cpp +++ b/src/openrct2/network/NetworkBase.cpp @@ -43,8 +43,6 @@ #include #include -using namespace OpenRCT2; - // This string specifies which version of network stream current build uses. // It is used for making sure only compatible builds get connected, even within // single OpenRCT2 version. @@ -104,4253 +102,4270 @@ static constexpr uint32_t kMaxPacketsPerUpdate = 100; #include #include -using namespace OpenRCT2; + using namespace OpenRCT2; -static void NetworkChatShowConnectedMessage(); -static void NetworkChatShowServerGreeting(); -static u8string NetworkGetKeysDirectory(); -static u8string NetworkGetPrivateKeyPath(u8string_view playerName); -static u8string NetworkGetPublicKeyPath(u8string_view playerName, u8string_view hash); + static void NetworkChatShowConnectedMessage(); + static void NetworkChatShowServerGreeting(); + static u8string NetworkGetKeysDirectory(); + static u8string NetworkGetPrivateKeyPath(u8string_view playerName); + static u8string NetworkGetPublicKeyPath(u8string_view playerName, u8string_view hash); -NetworkBase::NetworkBase(IContext& context) - : System(context) -{ - mode = NETWORK_MODE_NONE; - status = NETWORK_STATUS_NONE; - last_ping_sent_time = 0; - _actionId = 0; - - client_command_handlers[NetworkCommand::Auth] = &NetworkBase::Client_Handle_AUTH; - client_command_handlers[NetworkCommand::Map] = &NetworkBase::Client_Handle_MAP; - client_command_handlers[NetworkCommand::Chat] = &NetworkBase::Client_Handle_CHAT; - client_command_handlers[NetworkCommand::GameAction] = &NetworkBase::Client_Handle_GAME_ACTION; - client_command_handlers[NetworkCommand::Tick] = &NetworkBase::Client_Handle_TICK; - client_command_handlers[NetworkCommand::PlayerList] = &NetworkBase::Client_Handle_PLAYERLIST; - client_command_handlers[NetworkCommand::PlayerInfo] = &NetworkBase::Client_Handle_PLAYERINFO; - client_command_handlers[NetworkCommand::Ping] = &NetworkBase::Client_Handle_PING; - client_command_handlers[NetworkCommand::PingList] = &NetworkBase::Client_Handle_PINGLIST; - client_command_handlers[NetworkCommand::DisconnectMessage] = &NetworkBase::Client_Handle_SETDISCONNECTMSG; - client_command_handlers[NetworkCommand::ShowError] = &NetworkBase::Client_Handle_SHOWERROR; - client_command_handlers[NetworkCommand::GroupList] = &NetworkBase::Client_Handle_GROUPLIST; - client_command_handlers[NetworkCommand::Event] = &NetworkBase::Client_Handle_EVENT; - client_command_handlers[NetworkCommand::GameInfo] = &NetworkBase::Client_Handle_GAMEINFO; - client_command_handlers[NetworkCommand::Token] = &NetworkBase::Client_Handle_TOKEN; - client_command_handlers[NetworkCommand::ObjectsList] = &NetworkBase::Client_Handle_OBJECTS_LIST; - client_command_handlers[NetworkCommand::ScriptsHeader] = &NetworkBase::Client_Handle_SCRIPTS_HEADER; - client_command_handlers[NetworkCommand::ScriptsData] = &NetworkBase::Client_Handle_SCRIPTS_DATA; - client_command_handlers[NetworkCommand::GameState] = &NetworkBase::Client_Handle_GAMESTATE; - - server_command_handlers[NetworkCommand::Auth] = &NetworkBase::ServerHandleAuth; - server_command_handlers[NetworkCommand::Chat] = &NetworkBase::ServerHandleChat; - server_command_handlers[NetworkCommand::GameAction] = &NetworkBase::ServerHandleGameAction; - server_command_handlers[NetworkCommand::Ping] = &NetworkBase::ServerHandlePing; - server_command_handlers[NetworkCommand::GameInfo] = &NetworkBase::ServerHandleGameInfo; - server_command_handlers[NetworkCommand::Token] = &NetworkBase::ServerHandleToken; - server_command_handlers[NetworkCommand::MapRequest] = &NetworkBase::ServerHandleMapRequest; - server_command_handlers[NetworkCommand::RequestGameState] = &NetworkBase::ServerHandleRequestGamestate; - server_command_handlers[NetworkCommand::Heartbeat] = &NetworkBase::ServerHandleHeartbeat; - - _chat_log_fs << std::unitbuf; - _server_log_fs << std::unitbuf; -} - -bool NetworkBase::Init() -{ - status = NETWORK_STATUS_READY; - - ServerName.clear(); - ServerDescription.clear(); - ServerGreeting.clear(); - ServerProviderName.clear(); - ServerProviderEmail.clear(); - ServerProviderWebsite.clear(); - return true; -} - -void NetworkBase::Reconnect() -{ - if (status != NETWORK_STATUS_NONE) + NetworkBase::NetworkBase(IContext& context) + : System(context) { - Close(); - } - if (_requireClose) - { - _requireReconnect = true; - return; - } - BeginClient(_host, _port); -} + mode = NETWORK_MODE_NONE; + status = NETWORK_STATUS_NONE; + last_ping_sent_time = 0; + _actionId = 0; -void NetworkBase::Close() -{ - if (status != NETWORK_STATUS_NONE) - { - // HACK Because Close() is closed all over the place, it sometimes gets called inside an Update - // call. This then causes disposed data to be accessed. Therefore, save closing until the - // end of the update loop. - if (_closeLock) - { - _requireClose = true; - return; - } + client_command_handlers[NetworkCommand::Auth] = &NetworkBase::Client_Handle_AUTH; + client_command_handlers[NetworkCommand::Map] = &NetworkBase::Client_Handle_MAP; + client_command_handlers[NetworkCommand::Chat] = &NetworkBase::Client_Handle_CHAT; + client_command_handlers[NetworkCommand::GameAction] = &NetworkBase::Client_Handle_GAME_ACTION; + client_command_handlers[NetworkCommand::Tick] = &NetworkBase::Client_Handle_TICK; + client_command_handlers[NetworkCommand::PlayerList] = &NetworkBase::Client_Handle_PLAYERLIST; + client_command_handlers[NetworkCommand::PlayerInfo] = &NetworkBase::Client_Handle_PLAYERINFO; + client_command_handlers[NetworkCommand::Ping] = &NetworkBase::Client_Handle_PING; + client_command_handlers[NetworkCommand::PingList] = &NetworkBase::Client_Handle_PINGLIST; + client_command_handlers[NetworkCommand::DisconnectMessage] = &NetworkBase::Client_Handle_SETDISCONNECTMSG; + client_command_handlers[NetworkCommand::ShowError] = &NetworkBase::Client_Handle_SHOWERROR; + client_command_handlers[NetworkCommand::GroupList] = &NetworkBase::Client_Handle_GROUPLIST; + client_command_handlers[NetworkCommand::Event] = &NetworkBase::Client_Handle_EVENT; + client_command_handlers[NetworkCommand::GameInfo] = &NetworkBase::Client_Handle_GAMEINFO; + client_command_handlers[NetworkCommand::Token] = &NetworkBase::Client_Handle_TOKEN; + client_command_handlers[NetworkCommand::ObjectsList] = &NetworkBase::Client_Handle_OBJECTS_LIST; + client_command_handlers[NetworkCommand::ScriptsHeader] = &NetworkBase::Client_Handle_SCRIPTS_HEADER; + client_command_handlers[NetworkCommand::ScriptsData] = &NetworkBase::Client_Handle_SCRIPTS_DATA; + client_command_handlers[NetworkCommand::GameState] = &NetworkBase::Client_Handle_GAMESTATE; - CloseChatLog(); - CloseServerLog(); - CloseConnection(); + server_command_handlers[NetworkCommand::Auth] = &NetworkBase::ServerHandleAuth; + server_command_handlers[NetworkCommand::Chat] = &NetworkBase::ServerHandleChat; + server_command_handlers[NetworkCommand::GameAction] = &NetworkBase::ServerHandleGameAction; + server_command_handlers[NetworkCommand::Ping] = &NetworkBase::ServerHandlePing; + server_command_handlers[NetworkCommand::GameInfo] = &NetworkBase::ServerHandleGameInfo; + server_command_handlers[NetworkCommand::Token] = &NetworkBase::ServerHandleToken; + server_command_handlers[NetworkCommand::MapRequest] = &NetworkBase::ServerHandleMapRequest; + server_command_handlers[NetworkCommand::RequestGameState] = &NetworkBase::ServerHandleRequestGamestate; + server_command_handlers[NetworkCommand::Heartbeat] = &NetworkBase::ServerHandleHeartbeat; - client_connection_list.clear(); - GameActions::ClearQueue(); - GameActions::ResumeQueue(); - player_list.clear(); - group_list.clear(); - _serverTickData.clear(); - _pendingPlayerLists.clear(); - _pendingPlayerInfo.clear(); - - #ifdef ENABLE_SCRIPTING - auto& scriptEngine = GetContext().GetScriptEngine(); - scriptEngine.RemoveNetworkPlugins(); - #endif - - GfxInvalidateScreen(); - - _requireClose = false; - } -} - -void NetworkBase::DecayCooldown(NetworkPlayer* player) -{ - if (player == nullptr) - return; // No valid connection yet. - - for (auto it = std::begin(player->CooldownTime); it != std::end(player->CooldownTime);) - { - it->second -= _currentDeltaTime; - if (it->second <= 0) - it = player->CooldownTime.erase(it); - else - it++; - } -} - -void NetworkBase::CloseConnection() -{ - if (mode == NETWORK_MODE_CLIENT) - { - _serverConnection.reset(); - } - else if (mode == NETWORK_MODE_SERVER) - { - _listenSocket.reset(); - _advertiser.reset(); + _chat_log_fs << std::unitbuf; + _server_log_fs << std::unitbuf; } - mode = NETWORK_MODE_NONE; - status = NETWORK_STATUS_NONE; - _lastConnectStatus = SocketStatus::Closed; -} - -bool NetworkBase::BeginClient(const std::string& host, uint16_t port) -{ - if (GetMode() != NETWORK_MODE_NONE) + bool NetworkBase::Init() { - return false; - } - - Close(); - if (!Init()) - return false; - - mode = NETWORK_MODE_CLIENT; - - LOG_INFO("Connecting to %s:%u", host.c_str(), port); - _host = host; - _port = port; - - _serverConnection = std::make_unique(); - _serverConnection->Socket = CreateTcpSocket(); - _serverConnection->Socket->ConnectAsync(host, port); - _serverState.gamestateSnapshotsEnabled = false; - - status = NETWORK_STATUS_CONNECTING; - _lastConnectStatus = SocketStatus::Closed; - _clientMapLoaded = false; - _serverTickData.clear(); - - BeginChatLog(); - BeginServerLog(); - - // We need to wait for the map load before we execute any actions. - // If the client has the title screen running then there's a potential - // risk of tick collision with the server map and title screen map. - GameActions::SuspendQueue(); - - auto keyPath = NetworkGetPrivateKeyPath(Config::Get().network.PlayerName); - if (!File::Exists(keyPath)) - { - Console::WriteLine("Generating key... This may take a while"); - Console::WriteLine("Need to collect enough entropy from the system"); - _key.Generate(); - Console::WriteLine("Key generated, saving private bits as %s", keyPath.c_str()); - - const auto keysDirectory = NetworkGetKeysDirectory(); - if (!Path::CreateDirectory(keysDirectory)) - { - LOG_ERROR("Unable to create directory %s.", keysDirectory.c_str()); - return false; - } - - try - { - auto fs = FileStream(keyPath, FileMode::write); - _key.SavePrivate(&fs); - } - catch (const std::exception&) - { - LOG_ERROR("Unable to save private key at %s.", keyPath.c_str()); - return false; - } - - const std::string hash = _key.PublicKeyHash(); - const utf8* publicKeyHash = hash.c_str(); - keyPath = NetworkGetPublicKeyPath(Config::Get().network.PlayerName, publicKeyHash); - Console::WriteLine("Key generated, saving public bits as %s", keyPath.c_str()); - - try - { - auto fs = FileStream(keyPath, FileMode::write); - _key.SavePublic(&fs); - } - catch (const std::exception&) - { - LOG_ERROR("Unable to save public key at %s.", keyPath.c_str()); - return false; - } - } - else - { - // LoadPrivate returns validity of loaded key - bool ok = false; - try - { - LOG_VERBOSE("Loading key from %s", keyPath.c_str()); - auto fs = FileStream(keyPath, FileMode::open); - ok = _key.LoadPrivate(&fs); - } - catch (const std::exception&) - { - LOG_ERROR("Unable to read private key from %s.", keyPath.c_str()); - return false; - } - - // Don't store private key in memory when it's not in use. - _key.Unload(); - return ok; - } - - return true; -} - -bool NetworkBase::BeginServer(uint16_t port, const std::string& address) -{ - Close(); - if (!Init()) - return false; - - mode = NETWORK_MODE_SERVER; - - _userManager.Load(); - - LOG_VERBOSE("Begin listening for clients"); - - _listenSocket = CreateTcpSocket(); - try - { - _listenSocket->Listen(address, port); - } - catch (const std::exception& ex) - { - Console::Error::WriteLine(ex.what()); - Close(); - return false; - } - - ServerName = Config::Get().network.ServerName; - ServerDescription = Config::Get().network.ServerDescription; - ServerGreeting = Config::Get().network.ServerGreeting; - ServerProviderName = Config::Get().network.ProviderName; - ServerProviderEmail = Config::Get().network.ProviderEmail; - ServerProviderWebsite = Config::Get().network.ProviderWebsite; - - IsServerPlayerInvisible = gOpenRCT2Headless; - - LoadGroups(); - BeginChatLog(); - BeginServerLog(); - - NetworkPlayer* player = AddPlayer(Config::Get().network.PlayerName, ""); - player->Flags |= NETWORK_PLAYER_FLAG_ISSERVER; - player->Group = 0; - player_id = player->Id; - - if (NetworkGetMode() == NETWORK_MODE_SERVER) - { - // Add SERVER to users.json and save. - NetworkUser* networkUser = _userManager.GetOrAddUser(player->KeyHash); - networkUser->GroupId = player->Group; - networkUser->Name = player->Name; - _userManager.Save(); - } - - auto* szAddress = address.empty() ? "*" : address.c_str(); - Console::WriteLine("Listening for clients on %s:%hu", szAddress, port); - NetworkChatShowConnectedMessage(); - NetworkChatShowServerGreeting(); - - status = NETWORK_STATUS_CONNECTED; - listening_port = port; - _serverState.gamestateSnapshotsEnabled = Config::Get().network.DesyncDebugging; - _advertiser = CreateServerAdvertiser(listening_port); - - GameLoadScripts(); - GameNotifyMapChanged(); - - return true; -} - -int32_t NetworkBase::GetMode() const noexcept -{ - return mode; -} - -int32_t NetworkBase::GetStatus() const noexcept -{ - return status; -} - -NetworkAuth NetworkBase::GetAuthStatus() -{ - if (GetMode() == NETWORK_MODE_CLIENT) - { - return _serverConnection->AuthStatus; - } - if (GetMode() == NETWORK_MODE_SERVER) - { - return NetworkAuth::Ok; - } - return NetworkAuth::None; -} - -uint32_t NetworkBase::GetServerTick() const noexcept -{ - return _serverState.tick; -} - -uint8_t NetworkBase::GetPlayerID() const noexcept -{ - return player_id; -} - -NetworkConnection* NetworkBase::GetPlayerConnection(uint8_t id) const -{ - auto player = GetPlayerByID(id); - if (player != nullptr) - { - auto clientIt = std::find_if( - client_connection_list.begin(), client_connection_list.end(), - [player](const auto& conn) -> bool { return conn->Player == player; }); - return clientIt != client_connection_list.end() ? clientIt->get() : nullptr; - } - return nullptr; -} - -void NetworkBase::Update() -{ - _closeLock = true; - - // Update is not necessarily called per game tick, maintain our own delta time - uint32_t ticks = Platform::GetTicks(); - _currentDeltaTime = std::max(ticks - _lastUpdateTime, 1); - _lastUpdateTime = ticks; - - switch (GetMode()) - { - case NETWORK_MODE_SERVER: - UpdateServer(); - break; - case NETWORK_MODE_CLIENT: - UpdateClient(); - break; - } - - // If the Close() was called during the update, close it for real - _closeLock = false; - if (_requireClose) - { - Close(); - if (_requireReconnect) - { - Reconnect(); - } - } -} - -void NetworkBase::Flush() -{ - if (GetMode() == NETWORK_MODE_CLIENT) - { - _serverConnection->SendQueuedData(); - } - else - { - for (auto& it : client_connection_list) - { - it->SendQueuedData(); - } - } -} - -void NetworkBase::UpdateServer() -{ - for (auto& connection : client_connection_list) - { - // This can be called multiple times before the connection is removed. - if (!connection->IsValid()) - continue; - - if (!ProcessConnection(*connection)) - { - connection->Disconnect(); - } - else - { - DecayCooldown(connection->Player); - } - } - - uint32_t ticks = Platform::GetTicks(); - if (ticks > last_ping_sent_time + 3000) - { - ServerSendPing(); - ServerSendPingList(); - } - - if (_advertiser != nullptr) - { - _advertiser->Update(); - } - - std::unique_ptr tcpSocket = _listenSocket->Accept(); - if (tcpSocket != nullptr) - { - AddClient(std::move(tcpSocket)); - } -} - -void NetworkBase::UpdateClient() -{ - assert(_serverConnection != nullptr); - - switch (status) - { - case NETWORK_STATUS_CONNECTING: - { - switch (_serverConnection->Socket->GetStatus()) - { - case SocketStatus::Resolving: - { - if (_lastConnectStatus != SocketStatus::Resolving) - { - _lastConnectStatus = SocketStatus::Resolving; - char str_resolving[256]; - FormatStringLegacy(str_resolving, 256, STR_MULTIPLAYER_RESOLVING, nullptr); - - auto intent = Intent(WindowClass::NetworkStatus); - intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_resolving }); - intent.PutExtra(INTENT_EXTRA_CALLBACK, []() -> void { ::GetContext()->GetNetwork().Close(); }); - ContextOpenIntent(&intent); - } - break; - } - case SocketStatus::Connecting: - { - if (_lastConnectStatus != SocketStatus::Connecting) - { - _lastConnectStatus = SocketStatus::Connecting; - char str_connecting[256]; - FormatStringLegacy(str_connecting, 256, STR_MULTIPLAYER_CONNECTING, nullptr); - - auto intent = Intent(WindowClass::NetworkStatus); - intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_connecting }); - intent.PutExtra(INTENT_EXTRA_CALLBACK, []() -> void { ::GetContext()->GetNetwork().Close(); }); - ContextOpenIntent(&intent); - - server_connect_time = Platform::GetTicks(); - } - break; - } - case SocketStatus::Connected: - { - status = NETWORK_STATUS_CONNECTED; - _serverConnection->ResetLastPacketTime(); - Client_Send_TOKEN(); - char str_authenticating[256]; - FormatStringLegacy(str_authenticating, 256, STR_MULTIPLAYER_AUTHENTICATING, nullptr); - - auto intent = Intent(WindowClass::NetworkStatus); - intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_authenticating }); - intent.PutExtra(INTENT_EXTRA_CALLBACK, []() -> void { ::GetContext()->GetNetwork().Close(); }); - ContextOpenIntent(&intent); - break; - } - default: - { - const char* error = _serverConnection->Socket->GetError(); - if (error != nullptr) - { - Console::Error::WriteLine(error); - } - - Close(); - ContextForceCloseWindowByClass(WindowClass::NetworkStatus); - ContextShowError(STR_UNABLE_TO_CONNECT_TO_SERVER, kStringIdNone, {}); - break; - } - } - break; - } - case NETWORK_STATUS_CONNECTED: - { - if (!ProcessConnection(*_serverConnection)) - { - // Do not show disconnect message window when password window closed/canceled - if (_serverConnection->AuthStatus == NetworkAuth::RequirePassword) - { - ContextForceCloseWindowByClass(WindowClass::NetworkStatus); - } - else - { - char str_disconnected[256]; - - if (_serverConnection->GetLastDisconnectReason()) - { - const char* disconnect_reason = _serverConnection->GetLastDisconnectReason(); - FormatStringLegacy(str_disconnected, 256, STR_MULTIPLAYER_DISCONNECTED_WITH_REASON, &disconnect_reason); - } - else - { - FormatStringLegacy(str_disconnected, 256, STR_MULTIPLAYER_DISCONNECTED_NO_REASON, nullptr); - } - - auto intent = Intent(WindowClass::NetworkStatus); - intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_disconnected }); - ContextOpenIntent(&intent); - } - - auto* windowMgr = Ui::GetWindowManager(); - windowMgr->CloseByClass(WindowClass::Multiplayer); - Close(); - } - else - { - uint32_t ticks = Platform::GetTicks(); - if (ticks - _lastSentHeartbeat >= 3000) - { - Client_Send_HEARTBEAT(*_serverConnection); - _lastSentHeartbeat = ticks; - } - } - - break; - } - } -} - -auto NetworkBase::GetPlayerIteratorByID(uint8_t id) const -{ - return std::find_if(player_list.begin(), player_list.end(), [id](std::unique_ptr const& player) { - return player->Id == id; - }); -} - -NetworkPlayer* NetworkBase::GetPlayerByID(uint8_t id) const -{ - auto it = GetPlayerIteratorByID(id); - if (it != player_list.end()) - { - return it->get(); - } - return nullptr; -} - -auto NetworkBase::GetGroupIteratorByID(uint8_t id) const -{ - return std::find_if( - group_list.begin(), group_list.end(), [id](std::unique_ptr const& group) { return group->Id == id; }); -} - -NetworkGroup* NetworkBase::GetGroupByID(uint8_t id) const -{ - auto it = GetGroupIteratorByID(id); - if (it != group_list.end()) - { - return it->get(); - } - return nullptr; -} - -int32_t NetworkBase::GetTotalNumPlayers() const noexcept -{ - return static_cast(player_list.size()); -} - -int32_t NetworkBase::GetNumVisiblePlayers() const noexcept -{ - if (IsServerPlayerInvisible) - return static_cast(player_list.size() - 1); - return static_cast(player_list.size()); -} - -const char* NetworkBase::FormatChat(NetworkPlayer* fromPlayer, const char* text) -{ - static std::string formatted; - formatted.clear(); - - if (fromPlayer != nullptr) - { - auto& network = OpenRCT2::GetContext()->GetNetwork(); - auto it = network.GetGroupByID(fromPlayer->Id); - std::string groupName = ""; - std::vector colours; - if (it != nullptr) - { - groupName = it->GetName(); - if (groupName[0] != '{') - { - colours.push_back("{WHITE}"); - } - } - - for (size_t i = 0; i < groupName.size(); ++i) - { - if (groupName[i] == '{') - { - std::string colour = "{"; - ++i; - while (i < groupName.size() && groupName[i] != '}' && groupName[i] != '{') - { - colour += groupName[i]; - ++i; - } - colour += '}'; - if (groupName[i] == '}' && i < groupName.size()) - { - colours.push_back(colour); - } - } - } - - if (colours.size() == 0 || (colours.size() == 1 && colours[0] == "{WHITE}")) - { - formatted += "{BABYBLUE}"; - formatted += fromPlayer->Name; - } - else - { - size_t j = 0; - size_t proportionalSize = fromPlayer->Name.size() / colours.size(); - for (size_t i = 0; i < colours.size(); ++i) - { - formatted += colours[i]; - size_t numCharacters = proportionalSize + j; - for (; j < numCharacters && j < fromPlayer->Name.size(); ++j) - { - formatted += fromPlayer->Name[j]; - } - } - while (j < fromPlayer->Name.size()) - { - formatted += fromPlayer->Name[j]; - j++; - } - } - - formatted += ": "; - } - formatted += "{WHITE}"; - formatted += text; - return formatted.c_str(); -} - -void NetworkBase::SendPacketToClients(const NetworkPacket& packet, bool front, bool gameCmd) const -{ - for (auto& client_connection : client_connection_list) - { - if (gameCmd) - { - // If marked as game command we can not send the packet to connections that are not fully connected. - // Sending the packet would cause the client to store a command that is behind the tick where he starts, - // which would be essentially never executed. The clients do not require commands before the server has not sent the - // map data. - if (client_connection->Player == nullptr) - { - continue; - } - } - client_connection->QueuePacket(packet, front); - } -} - -bool NetworkBase::CheckSRAND(uint32_t tick, uint32_t srand0) -{ - // We have to wait for the map to be loaded first, ticks may match current loaded map. - if (!_clientMapLoaded) + status = NETWORK_STATUS_READY; + + ServerName.clear(); + ServerDescription.clear(); + ServerGreeting.clear(); + ServerProviderName.clear(); + ServerProviderEmail.clear(); + ServerProviderWebsite.clear(); return true; - - auto itTickData = _serverTickData.find(tick); - if (itTickData == std::end(_serverTickData)) - return true; - - const ServerTickData storedTick = itTickData->second; - _serverTickData.erase(itTickData); - - if (storedTick.srand0 != srand0) - { - LOG_INFO("Srand0 mismatch, client = %08X, server = %08X", srand0, storedTick.srand0); - return false; } - if (!storedTick.spriteHash.empty()) + void NetworkBase::Reconnect() { - EntitiesChecksum checksum = getGameState().entities.GetAllEntitiesChecksum(); - std::string clientSpriteHash = checksum.ToString(); - if (clientSpriteHash != storedTick.spriteHash) - { - LOG_INFO("Sprite hash mismatch, client = %s, server = %s", clientSpriteHash.c_str(), storedTick.spriteHash.c_str()); - return false; - } - } - - return true; -} - -bool NetworkBase::IsDesynchronised() const noexcept -{ - return _serverState.state == NetworkServerStatus::Desynced; -} - -bool NetworkBase::CheckDesynchronizaton() -{ - const auto currentTicks = getGameState().currentTicks; - - // Check synchronisation - if (GetMode() == NETWORK_MODE_CLIENT && _serverState.state != NetworkServerStatus::Desynced - && !CheckSRAND(currentTicks, ScenarioRandState().s0)) - { - _serverState.state = NetworkServerStatus::Desynced; - _serverState.desyncTick = currentTicks; - - char str_desync[256]; - FormatStringLegacy(str_desync, 256, STR_MULTIPLAYER_DESYNC, nullptr); - - auto intent = Intent(WindowClass::NetworkStatus); - intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_desync }); - ContextOpenIntent(&intent); - - if (!Config::Get().network.StayConnected) + if (status != NETWORK_STATUS_NONE) { Close(); } + if (_requireClose) + { + _requireReconnect = true; + return; + } + BeginClient(_host, _port); + } + + void NetworkBase::Close() + { + if (status != NETWORK_STATUS_NONE) + { + // HACK Because Close() is closed all over the place, it sometimes gets called inside an Update + // call. This then causes disposed data to be accessed. Therefore, save closing until the + // end of the update loop. + if (_closeLock) + { + _requireClose = true; + return; + } + + CloseChatLog(); + CloseServerLog(); + CloseConnection(); + + client_connection_list.clear(); + GameActions::ClearQueue(); + GameActions::ResumeQueue(); + player_list.clear(); + group_list.clear(); + _serverTickData.clear(); + _pendingPlayerLists.clear(); + _pendingPlayerInfo.clear(); + + #ifdef ENABLE_SCRIPTING + auto& scriptEngine = GetContext().GetScriptEngine(); + scriptEngine.RemoveNetworkPlugins(); + #endif + + GfxInvalidateScreen(); + + _requireClose = false; + } + } + + void NetworkBase::DecayCooldown(NetworkPlayer* player) + { + if (player == nullptr) + return; // No valid connection yet. + + for (auto it = std::begin(player->CooldownTime); it != std::end(player->CooldownTime);) + { + it->second -= _currentDeltaTime; + if (it->second <= 0) + it = player->CooldownTime.erase(it); + else + it++; + } + } + + void NetworkBase::CloseConnection() + { + if (mode == NETWORK_MODE_CLIENT) + { + _serverConnection.reset(); + } + else if (mode == NETWORK_MODE_SERVER) + { + _listenSocket.reset(); + _advertiser.reset(); + } + + mode = NETWORK_MODE_NONE; + status = NETWORK_STATUS_NONE; + _lastConnectStatus = SocketStatus::Closed; + } + + bool NetworkBase::BeginClient(const std::string& host, uint16_t port) + { + if (GetMode() != NETWORK_MODE_NONE) + { + return false; + } + + Close(); + if (!Init()) + return false; + + mode = NETWORK_MODE_CLIENT; + + LOG_INFO("Connecting to %s:%u", host.c_str(), port); + _host = host; + _port = port; + + _serverConnection = std::make_unique(); + _serverConnection->Socket = CreateTcpSocket(); + _serverConnection->Socket->ConnectAsync(host, port); + _serverState.gamestateSnapshotsEnabled = false; + + status = NETWORK_STATUS_CONNECTING; + _lastConnectStatus = SocketStatus::Closed; + _clientMapLoaded = false; + _serverTickData.clear(); + + BeginChatLog(); + BeginServerLog(); + + // We need to wait for the map load before we execute any actions. + // If the client has the title screen running then there's a potential + // risk of tick collision with the server map and title screen map. + GameActions::SuspendQueue(); + + auto keyPath = NetworkGetPrivateKeyPath(Config::Get().network.PlayerName); + if (!File::Exists(keyPath)) + { + Console::WriteLine("Generating key... This may take a while"); + Console::WriteLine("Need to collect enough entropy from the system"); + _key.Generate(); + Console::WriteLine("Key generated, saving private bits as %s", keyPath.c_str()); + + const auto keysDirectory = NetworkGetKeysDirectory(); + if (!Path::CreateDirectory(keysDirectory)) + { + LOG_ERROR("Unable to create directory %s.", keysDirectory.c_str()); + return false; + } + + try + { + auto fs = FileStream(keyPath, FileMode::write); + _key.SavePrivate(&fs); + } + catch (const std::exception&) + { + LOG_ERROR("Unable to save private key at %s.", keyPath.c_str()); + return false; + } + + const std::string hash = _key.PublicKeyHash(); + const utf8* publicKeyHash = hash.c_str(); + keyPath = NetworkGetPublicKeyPath(Config::Get().network.PlayerName, publicKeyHash); + Console::WriteLine("Key generated, saving public bits as %s", keyPath.c_str()); + + try + { + auto fs = FileStream(keyPath, FileMode::write); + _key.SavePublic(&fs); + } + catch (const std::exception&) + { + LOG_ERROR("Unable to save public key at %s.", keyPath.c_str()); + return false; + } + } + else + { + // LoadPrivate returns validity of loaded key + bool ok = false; + try + { + LOG_VERBOSE("Loading key from %s", keyPath.c_str()); + auto fs = FileStream(keyPath, FileMode::open); + ok = _key.LoadPrivate(&fs); + } + catch (const std::exception&) + { + LOG_ERROR("Unable to read private key from %s.", keyPath.c_str()); + return false; + } + + // Don't store private key in memory when it's not in use. + _key.Unload(); + return ok; + } return true; } - return false; -} - -void NetworkBase::RequestStateSnapshot() -{ - LOG_INFO("Requesting game state for tick %u", _serverState.desyncTick); - - Client_Send_RequestGameState(_serverState.desyncTick); -} - -NetworkServerState NetworkBase::GetServerState() const noexcept -{ - return _serverState; -} - -void NetworkBase::KickPlayer(int32_t playerId) -{ - for (auto& client_connection : client_connection_list) + bool NetworkBase::BeginServer(uint16_t port, const std::string& address) { - if (client_connection->Player->Id == playerId) - { - // Disconnect the client gracefully - client_connection->SetLastDisconnectReason(STR_MULTIPLAYER_KICKED); - char str_disconnect_msg[256]; - FormatStringLegacy(str_disconnect_msg, 256, STR_MULTIPLAYER_KICKED_REASON, nullptr); - ServerSendSetDisconnectMsg(*client_connection, str_disconnect_msg); - client_connection->Disconnect(); - break; - } - } -} + Close(); + if (!Init()) + return false; -void NetworkBase::SetPassword(u8string_view password) -{ - _password = password; -} + mode = NETWORK_MODE_SERVER; -void NetworkBase::ServerClientDisconnected() -{ - if (GetMode() == NETWORK_MODE_CLIENT) - { - _serverConnection->Disconnect(); - } -} + _userManager.Load(); -std::string NetworkBase::GenerateAdvertiseKey() -{ - // Generate a string of 16 random hex characters (64-integer key as a hex formatted string) - static char hexChars[] = { - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', - }; - char key[17]; - for (int32_t i = 0; i < 16; i++) - { - int32_t hexCharIndex = UtilRand() % std::size(hexChars); - key[i] = hexChars[hexCharIndex]; - } - key[std::size(key) - 1] = 0; + LOG_VERBOSE("Begin listening for clients"); - return key; -} - -std::string NetworkBase::GetMasterServerUrl() -{ - if (Config::Get().network.MasterServerUrl.empty()) - { - return kMasterServerURL; - } - - return Config::Get().network.MasterServerUrl; -} - -NetworkGroup* NetworkBase::AddGroup() -{ - NetworkGroup* addedgroup = nullptr; - int32_t newid = -1; - // Find first unused group id - for (int32_t id = 0; id < 255; id++) - { - if (std::find_if( - group_list.begin(), group_list.end(), - [&id](std::unique_ptr const& group) { return group->Id == id; }) - == group_list.end()) - { - newid = id; - break; - } - } - if (newid != -1) - { - auto group = std::make_unique(); - group->Id = newid; - group->SetName("Group #" + std::to_string(newid)); - addedgroup = group.get(); - group_list.push_back(std::move(group)); - } - return addedgroup; -} - -void NetworkBase::RemoveGroup(uint8_t id) -{ - auto group = GetGroupIteratorByID(id); - if (group != group_list.end()) - { - group_list.erase(group); - } - - if (GetMode() == NETWORK_MODE_SERVER) - { - _userManager.UnsetUsersOfGroup(id); - _userManager.Save(); - } -} - -uint8_t NetworkBase::GetGroupIDByHash(const std::string& keyhash) -{ - const NetworkUser* networkUser = _userManager.GetUserByHash(keyhash); - - uint8_t groupId = GetDefaultGroup(); - if (networkUser != nullptr && networkUser->GroupId.has_value()) - { - const uint8_t assignedGroup = *networkUser->GroupId; - if (GetGroupByID(assignedGroup) != nullptr) - { - groupId = assignedGroup; - } - else - { - LOG_WARNING( - "User %s is assigned to non-existent group %u. Assigning to default group (%u)", keyhash.c_str(), assignedGroup, - groupId); - } - } - return groupId; -} - -uint8_t NetworkBase::GetDefaultGroup() const noexcept -{ - return default_group; -} - -void NetworkBase::SetDefaultGroup(uint8_t id) -{ - if (GetGroupByID(id) != nullptr) - { - default_group = id; - } -} - -void NetworkBase::SaveGroups() -{ - if (GetMode() == NETWORK_MODE_SERVER) - { - auto& env = GetContext().GetPlatformEnvironment(); - auto path = Path::Combine(env.GetDirectoryPath(DirBase::user), u8"groups.json"); - - json_t jsonGroups = json_t::array(); - for (auto& group : group_list) - { - jsonGroups.push_back(group->ToJson()); - } - json_t jsonGroupsCfg = { - { "default_group", default_group }, - { "groups", jsonGroups }, - }; + _listenSocket = CreateTcpSocket(); try { - Json::WriteToFile(path, jsonGroupsCfg); + _listenSocket->Listen(address, port); } catch (const std::exception& ex) { - LOG_ERROR("Unable to save %s: %s", path.c_str(), ex.what()); + Console::Error::WriteLine(ex.what()); + Close(); + return false; } - } -} -void NetworkBase::SetupDefaultGroups() -{ - // Admin group - auto admin = std::make_unique(); - admin->SetName("Admin"); - admin->ActionsAllowed.fill(0xFF); - admin->Id = 0; - group_list.push_back(std::move(admin)); + ServerName = Config::Get().network.ServerName; + ServerDescription = Config::Get().network.ServerDescription; + ServerGreeting = Config::Get().network.ServerGreeting; + ServerProviderName = Config::Get().network.ProviderName; + ServerProviderEmail = Config::Get().network.ProviderEmail; + ServerProviderWebsite = Config::Get().network.ProviderWebsite; - // Spectator group - auto spectator = std::make_unique(); - spectator->SetName("Spectator"); - spectator->ToggleActionPermission(NetworkPermission::Chat); - spectator->Id = 1; - group_list.push_back(std::move(spectator)); + IsServerPlayerInvisible = gOpenRCT2Headless; - // User group - auto user = std::make_unique(); - user->SetName("User"); - user->ActionsAllowed.fill(0xFF); - user->ToggleActionPermission(NetworkPermission::KickPlayer); - user->ToggleActionPermission(NetworkPermission::ModifyGroups); - user->ToggleActionPermission(NetworkPermission::SetPlayerGroup); - user->ToggleActionPermission(NetworkPermission::Cheat); - user->ToggleActionPermission(NetworkPermission::PasswordlessLogin); - user->ToggleActionPermission(NetworkPermission::ModifyTile); - user->ToggleActionPermission(NetworkPermission::EditScenarioOptions); - user->Id = 2; - group_list.push_back(std::move(user)); + LoadGroups(); + BeginChatLog(); + BeginServerLog(); - SetDefaultGroup(1); -} + NetworkPlayer* player = AddPlayer(Config::Get().network.PlayerName, ""); + player->Flags |= NETWORK_PLAYER_FLAG_ISSERVER; + player->Group = 0; + player_id = player->Id; -void NetworkBase::LoadGroups() -{ - group_list.clear(); - - auto& env = GetContext().GetPlatformEnvironment(); - auto path = Path::Combine(env.GetDirectoryPath(DirBase::user), u8"groups.json"); - - json_t jsonGroupConfig; - if (File::Exists(path)) - { - try + if (NetworkGetMode() == NETWORK_MODE_SERVER) { - jsonGroupConfig = Json::ReadFromFile(path); - } - catch (const std::exception& e) - { - LOG_ERROR("Failed to read %s as JSON. Setting default groups. %s", path.c_str(), e.what()); + // Add SERVER to users.json and save. + NetworkUser* networkUser = _userManager.GetOrAddUser(player->KeyHash); + networkUser->GroupId = player->Group; + networkUser->Name = player->Name; + _userManager.Save(); } + + auto* szAddress = address.empty() ? "*" : address.c_str(); + Console::WriteLine("Listening for clients on %s:%hu", szAddress, port); + NetworkChatShowConnectedMessage(); + NetworkChatShowServerGreeting(); + + status = NETWORK_STATUS_CONNECTED; + listening_port = port; + _serverState.gamestateSnapshotsEnabled = Config::Get().network.DesyncDebugging; + _advertiser = CreateServerAdvertiser(listening_port); + + GameLoadScripts(); + GameNotifyMapChanged(); + + return true; } - if (!jsonGroupConfig.is_object()) + int32_t NetworkBase::GetMode() const noexcept { - SetupDefaultGroups(); + return mode; } - else + + int32_t NetworkBase::GetStatus() const noexcept { - json_t jsonGroups = jsonGroupConfig["groups"]; - if (jsonGroups.is_array()) + return status; + } + + NetworkAuth NetworkBase::GetAuthStatus() + { + if (GetMode() == NETWORK_MODE_CLIENT) { - for (auto& jsonGroup : jsonGroups) + return _serverConnection->AuthStatus; + } + if (GetMode() == NETWORK_MODE_SERVER) + { + return NetworkAuth::Ok; + } + return NetworkAuth::None; + } + + uint32_t NetworkBase::GetServerTick() const noexcept + { + return _serverState.tick; + } + + uint8_t NetworkBase::GetPlayerID() const noexcept + { + return player_id; + } + + NetworkConnection* NetworkBase::GetPlayerConnection(uint8_t id) const + { + auto player = GetPlayerByID(id); + if (player != nullptr) + { + auto clientIt = std::find_if( + client_connection_list.begin(), client_connection_list.end(), + [player](const auto& conn) -> bool { return conn->Player == player; }); + return clientIt != client_connection_list.end() ? clientIt->get() : nullptr; + } + return nullptr; + } + + void NetworkBase::Update() + { + _closeLock = true; + + // Update is not necessarily called per game tick, maintain our own delta time + uint32_t ticks = Platform::GetTicks(); + _currentDeltaTime = std::max(ticks - _lastUpdateTime, 1); + _lastUpdateTime = ticks; + + switch (GetMode()) + { + case NETWORK_MODE_SERVER: + UpdateServer(); + break; + case NETWORK_MODE_CLIENT: + UpdateClient(); + break; + } + + // If the Close() was called during the update, close it for real + _closeLock = false; + if (_requireClose) + { + Close(); + if (_requireReconnect) { - group_list.emplace_back(std::make_unique(NetworkGroup::FromJson(jsonGroup))); + Reconnect(); } } + } - default_group = Json::GetNumber(jsonGroupConfig["default_group"]); - if (GetGroupByID(default_group) == nullptr) + void NetworkBase::Flush() + { + if (GetMode() == NETWORK_MODE_CLIENT) { - default_group = 0; - } - } - - // Host group should always contain all permissions. - group_list.at(0)->ActionsAllowed.fill(0xFF); -} - -std::string NetworkBase::BeginLog(const std::string& directory, const std::string& midName, const std::string& filenameFormat) -{ - utf8 filename[256]; - time_t timer; - time(&timer); - auto tmInfo = localtime(&timer); - if (strftime(filename, sizeof(filename), filenameFormat.c_str(), tmInfo) == 0) - { - throw std::runtime_error("strftime failed"); - } - - auto directoryMidName = Path::Combine(directory, midName); - Path::CreateDirectory(directoryMidName); - return Path::Combine(directoryMidName, filename); -} - -void NetworkBase::AppendLog(std::ostream& fs, std::string_view s) -{ - if (fs.fail()) - { - LOG_ERROR("bad ostream failed to append log"); - return; - } - try - { - utf8 buffer[1024]; - time_t timer; - time(&timer); - auto tmInfo = localtime(&timer); - if (strftime(buffer, sizeof(buffer), "[%Y/%m/%d %H:%M:%S] ", tmInfo) != 0) - { - String::append(buffer, sizeof(buffer), std::string(s).c_str()); - String::append(buffer, sizeof(buffer), PLATFORM_NEWLINE); - - fs.write(buffer, strlen(buffer)); - } - } - catch (const std::exception& ex) - { - LOG_ERROR("%s", ex.what()); - } -} - -void NetworkBase::BeginChatLog() -{ - auto& env = GetContext().GetPlatformEnvironment(); - auto directory = env.GetDirectoryPath(DirBase::user, DirId::chatLogs); - _chatLogPath = BeginLog(directory, "", _chatLogFilenameFormat); - _chat_log_fs.open(fs::u8path(_chatLogPath), std::ios::out | std::ios::app); -} - -void NetworkBase::AppendChatLog(std::string_view s) -{ - if (Config::Get().network.LogChat && _chat_log_fs.is_open()) - { - AppendLog(_chat_log_fs, s); - } -} - -void NetworkBase::CloseChatLog() -{ - _chat_log_fs.close(); -} - -void NetworkBase::BeginServerLog() -{ - auto& env = GetContext().GetPlatformEnvironment(); - auto directory = env.GetDirectoryPath(DirBase::user, DirId::serverLogs); - _serverLogPath = BeginLog(directory, ServerName, _serverLogFilenameFormat); - _server_log_fs.open(fs::u8path(_serverLogPath), std::ios::out | std::ios::app | std::ios::binary); - - // Log server start event - utf8 logMessage[256]; - if (GetMode() == NETWORK_MODE_CLIENT) - { - FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_CLIENT_STARTED, nullptr); - } - else if (GetMode() == NETWORK_MODE_SERVER) - { - FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_SERVER_STARTED, nullptr); - } - else - { - logMessage[0] = '\0'; - Guard::Assert(false, "Unknown network mode!"); - } - AppendServerLog(logMessage); -} - -void NetworkBase::AppendServerLog(const std::string& s) -{ - if (Config::Get().network.LogServerActions && _server_log_fs.is_open()) - { - AppendLog(_server_log_fs, s); - } -} - -void NetworkBase::CloseServerLog() -{ - // Log server stopped event - char logMessage[256]; - if (GetMode() == NETWORK_MODE_CLIENT) - { - FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_CLIENT_STOPPED, nullptr); - } - else if (GetMode() == NETWORK_MODE_SERVER) - { - FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_SERVER_STOPPED, nullptr); - } - else - { - logMessage[0] = '\0'; - Guard::Assert(false, "Unknown network mode!"); - } - AppendServerLog(logMessage); - _server_log_fs.close(); -} - -void NetworkBase::Client_Send_RequestGameState(uint32_t tick) -{ - if (_serverState.gamestateSnapshotsEnabled == false) - { - LOG_VERBOSE("Server does not store a gamestate history"); - return; - } - - LOG_VERBOSE("Requesting gamestate from server for tick %u", tick); - - NetworkPacket packet(NetworkCommand::RequestGameState); - packet << tick; - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::Client_Send_TOKEN() -{ - LOG_VERBOSE("requesting token"); - NetworkPacket packet(NetworkCommand::Token); - _serverConnection->AuthStatus = NetworkAuth::Requested; - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::Client_Send_AUTH( - const std::string& name, const std::string& password, const std::string& pubkey, const std::vector& signature) -{ - NetworkPacket packet(NetworkCommand::Auth); - packet.WriteString(NetworkGetVersion()); - packet.WriteString(name); - packet.WriteString(password); - packet.WriteString(pubkey); - assert(signature.size() <= static_cast(UINT32_MAX)); - packet << static_cast(signature.size()); - packet.Write(signature.data(), signature.size()); - _serverConnection->AuthStatus = NetworkAuth::Requested; - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::Client_Send_MAPREQUEST(const std::vector& objects) -{ - LOG_VERBOSE("client requests %u objects", uint32_t(objects.size())); - NetworkPacket packet(NetworkCommand::MapRequest); - packet << static_cast(objects.size()); - for (const auto& object : objects) - { - std::string name(object.GetName()); - LOG_VERBOSE("client requests object %s", name.c_str()); - if (object.Generation == ObjectGeneration::DAT) - { - packet << static_cast(0); - packet.Write(&object.Entry, sizeof(RCTObjectEntry)); + _serverConnection->SendQueuedData(); } else { - packet << static_cast(1); - packet.WriteString(name); + for (auto& it : client_connection_list) + { + it->SendQueuedData(); + } } } - _serverConnection->QueuePacket(std::move(packet)); -} -void NetworkBase::ServerSendToken(NetworkConnection& connection) -{ - NetworkPacket packet(NetworkCommand::Token); - packet << static_cast(connection.Challenge.size()); - packet.Write(connection.Challenge.data(), connection.Challenge.size()); - connection.QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendObjectsList( - NetworkConnection& connection, const std::vector& objects) const -{ - LOG_VERBOSE("Server sends objects list with %u items", objects.size()); - - if (objects.empty()) - { - NetworkPacket packet(NetworkCommand::ObjectsList); - packet << static_cast(0) << static_cast(objects.size()); - - connection.QueuePacket(std::move(packet)); - } - else - { - for (size_t i = 0; i < objects.size(); ++i) - { - const auto* object = objects[i]; - - NetworkPacket packet(NetworkCommand::ObjectsList); - packet << static_cast(i) << static_cast(objects.size()); - - if (object->Identifier.empty()) - { - // DAT - LOG_VERBOSE("Object %.8s (checksum %x)", object->ObjectEntry.name, object->ObjectEntry.checksum); - packet << static_cast(0); - packet.Write(&object->ObjectEntry, sizeof(RCTObjectEntry)); - } - else - { - // JSON - LOG_VERBOSE("Object %s", object->Identifier.c_str()); - packet << static_cast(1); - packet.WriteString(object->Identifier); - } - - connection.QueuePacket(std::move(packet)); - } - } -} - -void NetworkBase::ServerSendScripts(NetworkConnection& connection) -{ - #ifdef ENABLE_SCRIPTING - using namespace OpenRCT2::Scripting; - - auto& scriptEngine = GetContext().GetScriptEngine(); - - // Get remote plugin list. - const auto remotePlugins = scriptEngine.GetRemotePlugins(); - LOG_VERBOSE("Server sends %zu scripts", remotePlugins.size()); - - // Build the data contents for each plugin. - MemoryStream pluginData; - for (auto& plugin : remotePlugins) - { - const auto& code = plugin->GetCode(); - - const auto codeSize = static_cast(code.size()); - pluginData.WriteValue(codeSize); - pluginData.WriteArray(code.c_str(), code.size()); - } - - // Send the header packet. - NetworkPacket packetScriptHeader(NetworkCommand::ScriptsHeader); - packetScriptHeader << static_cast(remotePlugins.size()); - packetScriptHeader << static_cast(pluginData.GetLength()); - connection.QueuePacket(std::move(packetScriptHeader)); - - // Segment the plugin data into chunks and send them. - const uint8_t* pluginDataBuffer = static_cast(pluginData.GetData()); - uint32_t dataOffset = 0; - while (dataOffset < pluginData.GetLength()) - { - const uint32_t chunkSize = std::min(pluginData.GetLength() - dataOffset, kChunkSize); - - NetworkPacket packet(NetworkCommand::ScriptsData); - packet << chunkSize; - packet.Write(pluginDataBuffer + dataOffset, chunkSize); - - connection.QueuePacket(std::move(packet)); - - dataOffset += chunkSize; - } - Guard::Assert(dataOffset == pluginData.GetLength()); - - #else - NetworkPacket packetScriptHeader(NetworkCommand::ScriptsHeader); - packetScriptHeader << static_cast(0u); - packetScriptHeader << static_cast(0u); - #endif -} - -void NetworkBase::Client_Send_HEARTBEAT(NetworkConnection& connection) const -{ - LOG_VERBOSE("Sending heartbeat"); - - NetworkPacket packet(NetworkCommand::Heartbeat); - connection.QueuePacket(std::move(packet)); -} - -NetworkStats NetworkBase::GetStats() const -{ - NetworkStats stats = {}; - if (mode == NETWORK_MODE_CLIENT) - { - stats = _serverConnection->Stats; - } - else + void NetworkBase::UpdateServer() { for (auto& connection : client_connection_list) { - for (size_t n = 0; n < EnumValue(NetworkStatisticsGroup::Max); n++) + // This can be called multiple times before the connection is removed. + if (!connection->IsValid()) + continue; + + if (!ProcessConnection(*connection)) { - stats.bytesReceived[n] += connection->Stats.bytesReceived[n]; - stats.bytesSent[n] += connection->Stats.bytesSent[n]; + connection->Disconnect(); + } + else + { + DecayCooldown(connection->Player); } } - } - return stats; -} -void NetworkBase::ServerSendAuth(NetworkConnection& connection) -{ - uint8_t new_playerid = 0; - if (connection.Player != nullptr) - { - new_playerid = connection.Player->Id; - } - NetworkPacket packet(NetworkCommand::Auth); - packet << static_cast(connection.AuthStatus) << new_playerid; - if (connection.AuthStatus == NetworkAuth::BadVersion) - { - packet.WriteString(NetworkGetVersion()); - } - connection.QueuePacket(std::move(packet)); - if (connection.AuthStatus != NetworkAuth::Ok && connection.AuthStatus != NetworkAuth::RequirePassword) - { - connection.Disconnect(); - } -} - -void NetworkBase::ServerSendMap(NetworkConnection* connection) -{ - std::vector objects; - if (connection != nullptr) - { - objects = connection->RequestedObjects; - } - else - { - // This will send all custom objects to connected clients - // TODO: fix it so custom objects negotiation is performed even in this case. - auto& context = GetContext(); - auto& objManager = context.GetObjectManager(); - objects = objManager.GetPackableObjects(); - } - - auto header = SaveForNetwork(objects); - if (header.empty()) - { - if (connection != nullptr) + uint32_t ticks = Platform::GetTicks(); + if (ticks > last_ping_sent_time + 3000) { - connection->SetLastDisconnectReason(STR_MULTIPLAYER_CONNECTION_CLOSED); - connection->Disconnect(); + ServerSendPing(); + ServerSendPingList(); } - return; - } - size_t chunksize = kChunkSize; - for (size_t i = 0; i < header.size(); i += chunksize) - { - size_t datasize = std::min(chunksize, header.size() - i); - NetworkPacket packet(NetworkCommand::Map); - packet << static_cast(header.size()) << static_cast(i); - packet.Write(&header[i], datasize); - if (connection != nullptr) + + if (_advertiser != nullptr) { - connection->QueuePacket(std::move(packet)); + _advertiser->Update(); } - else + + std::unique_ptr tcpSocket = _listenSocket->Accept(); + if (tcpSocket != nullptr) { - SendPacketToClients(packet); + AddClient(std::move(tcpSocket)); } } -} -std::vector NetworkBase::SaveForNetwork(const std::vector& objects) const -{ - std::vector result; - auto ms = MemoryStream(); - if (SaveMap(&ms, objects)) + void NetworkBase::UpdateClient() { - result.resize(ms.GetLength()); - std::memcpy(result.data(), ms.GetData(), result.size()); - } - else - { - LOG_WARNING("Failed to export map."); - } - return result; -} + assert(_serverConnection != nullptr); -void NetworkBase::Client_Send_CHAT(const char* text) -{ - NetworkPacket packet(NetworkCommand::Chat); - packet.WriteString(text); - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendChat(const char* text, const std::vector& playerIds) -{ - NetworkPacket packet(NetworkCommand::Chat); - packet.WriteString(text); - - if (playerIds.empty()) - { - // Empty players / default value means send to all players - SendPacketToClients(packet); - } - else - { - for (auto playerId : playerIds) + switch (status) { - auto conn = GetPlayerConnection(playerId); - if (conn != nullptr) + case NETWORK_STATUS_CONNECTING: { - conn->QueuePacket(packet); - } - } - } -} - -void NetworkBase::Client_Send_GAME_ACTION(const GameActions::GameAction* action) -{ - NetworkPacket packet(NetworkCommand::GameAction); - - uint32_t networkId = 0; - networkId = ++_actionId; - - // I know its ugly, want basic functionality for now. - const_cast(action)->SetNetworkId(networkId); - if (action->GetCallback()) - { - _gameActionCallbacks.insert(std::make_pair(networkId, action->GetCallback())); - } - - DataSerialiser stream(true); - action->Serialise(stream); - - packet << getGameState().currentTicks << action->GetType() << stream; - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendGameAction(const GameActions::GameAction* action) -{ - NetworkPacket packet(NetworkCommand::GameAction); - - DataSerialiser stream(true); - action->Serialise(stream); - - packet << getGameState().currentTicks << action->GetType() << stream; - - SendPacketToClients(packet); -} - -void NetworkBase::ServerSendTick() -{ - NetworkPacket packet(NetworkCommand::Tick); - packet << getGameState().currentTicks << ScenarioRandState().s0; - uint32_t flags = 0; - // Simple counter which limits how often a sprite checksum gets sent. - // This can get somewhat expensive, so we don't want to push it every tick in release, - // but debug version can check more often. - static int32_t checksum_counter = 0; - checksum_counter++; - if (checksum_counter >= 100) - { - checksum_counter = 0; - flags |= NETWORK_TICK_FLAG_CHECKSUMS; - } - // Send flags always, so we can understand packet structure on the other end, - // and allow for some expansion. - packet << flags; - if (flags & NETWORK_TICK_FLAG_CHECKSUMS) - { - EntitiesChecksum checksum = getGameState().entities.GetAllEntitiesChecksum(); - packet.WriteString(checksum.ToString()); - } - - SendPacketToClients(packet); -} - -void NetworkBase::ServerSendPlayerInfo(int32_t playerId) -{ - NetworkPacket packet(NetworkCommand::PlayerInfo); - packet << getGameState().currentTicks; - - auto* player = GetPlayerByID(playerId); - if (player == nullptr) - return; - - player->Write(packet); - SendPacketToClients(packet); -} - -void NetworkBase::ServerSendPlayerList() -{ - NetworkPacket packet(NetworkCommand::PlayerList); - packet << getGameState().currentTicks << static_cast(player_list.size()); - for (auto& player : player_list) - { - player->Write(packet); - } - SendPacketToClients(packet); -} - -void NetworkBase::Client_Send_PING() -{ - NetworkPacket packet(NetworkCommand::Ping); - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendPing() -{ - last_ping_sent_time = Platform::GetTicks(); - NetworkPacket packet(NetworkCommand::Ping); - for (auto& client_connection : client_connection_list) - { - client_connection->PingTime = Platform::GetTicks(); - } - SendPacketToClients(packet, true); -} - -void NetworkBase::ServerSendPingList() -{ - NetworkPacket packet(NetworkCommand::PingList); - packet << static_cast(player_list.size()); - for (auto& player : player_list) - { - packet << player->Id << player->Ping; - } - SendPacketToClients(packet); -} - -void NetworkBase::ServerSendSetDisconnectMsg(NetworkConnection& connection, const char* msg) -{ - NetworkPacket packet(NetworkCommand::DisconnectMessage); - packet.WriteString(msg); - connection.QueuePacket(std::move(packet)); -} - -json_t NetworkBase::GetServerInfoAsJson() const -{ - json_t jsonObj = { - { "name", Config::Get().network.ServerName }, - { "requiresPassword", _password.size() > 0 }, - { "version", NetworkGetVersion() }, - { "players", GetNumVisiblePlayers() }, - { "maxPlayers", Config::Get().network.Maxplayers }, - { "description", Config::Get().network.ServerDescription }, - { "greeting", Config::Get().network.ServerGreeting }, - { "dedicated", gOpenRCT2Headless }, - }; - return jsonObj; -} - -void NetworkBase::ServerSendGameInfo(NetworkConnection& connection) -{ - NetworkPacket packet(NetworkCommand::GameInfo); - #ifndef DISABLE_HTTP - json_t jsonObj = GetServerInfoAsJson(); - - // Provider details - json_t jsonProvider = { - { "name", Config::Get().network.ProviderName }, - { "email", Config::Get().network.ProviderEmail }, - { "website", Config::Get().network.ProviderWebsite }, - }; - - jsonObj["provider"] = jsonProvider; - - packet.WriteString(jsonObj.dump()); - packet << _serverState.gamestateSnapshotsEnabled; - packet << IsServerPlayerInvisible; - - #endif - connection.QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendShowError(NetworkConnection& connection, StringId title, StringId message) -{ - NetworkPacket packet(NetworkCommand::ShowError); - packet << title << message; - connection.QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendGroupList(NetworkConnection& connection) -{ - NetworkPacket packet(NetworkCommand::GroupList); - packet << static_cast(group_list.size()) << default_group; - for (auto& group : group_list) - { - group->Write(packet); - } - connection.QueuePacket(std::move(packet)); -} - -void NetworkBase::ServerSendEventPlayerJoined(const char* playerName) -{ - NetworkPacket packet(NetworkCommand::Event); - packet << static_cast(SERVER_EVENT_PLAYER_JOINED); - packet.WriteString(playerName); - SendPacketToClients(packet); -} - -void NetworkBase::ServerSendEventPlayerDisconnected(const char* playerName, const char* reason) -{ - NetworkPacket packet(NetworkCommand::Event); - packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); - packet.WriteString(playerName); - packet.WriteString(reason); - SendPacketToClients(packet); -} - -bool NetworkBase::ProcessConnection(NetworkConnection& connection) -{ - NetworkReadPacket packetStatus; - - uint32_t countProcessed = 0; - do - { - countProcessed++; - packetStatus = connection.ReadPacket(); - switch (packetStatus) - { - case NetworkReadPacket::Disconnected: - // closed connection or network error - if (!connection.GetLastDisconnectReason()) + switch (_serverConnection->Socket->GetStatus()) { - connection.SetLastDisconnectReason(STR_MULTIPLAYER_CONNECTION_CLOSED); - } - return false; - case NetworkReadPacket::Success: - // done reading in packet - ProcessPacket(connection, connection.InboundPacket); - if (!connection.IsValid()) - { - return false; - } - break; - case NetworkReadPacket::MoreData: - // more data required to be read - break; - case NetworkReadPacket::NoData: - // could not read anything from socket - break; - } - } while (packetStatus == NetworkReadPacket::Success && countProcessed < kMaxPacketsPerUpdate); - - if (!connection.ReceivedPacketRecently()) - { - if (!connection.GetLastDisconnectReason()) - { - connection.SetLastDisconnectReason(STR_MULTIPLAYER_NO_DATA); - } - return false; - } - - return true; -} - -void NetworkBase::ProcessPacket(NetworkConnection& connection, NetworkPacket& packet) -{ - const auto& handlerList = GetMode() == NETWORK_MODE_SERVER ? server_command_handlers : client_command_handlers; - - auto it = handlerList.find(packet.GetCommand()); - if (it != handlerList.end()) - { - auto commandHandler = it->second; - if (connection.AuthStatus == NetworkAuth::Ok || !packet.CommandRequiresAuth()) - { - try - { - (this->*commandHandler)(connection, packet); - } - catch (const std::exception& ex) - { - LOG_VERBOSE("Exception during packet processing: %s", ex.what()); - } - } - } - - packet.Clear(); -} - -// This is called at the end of each game tick, this where things should be processed that affects the game state. -void NetworkBase::ProcessPending() -{ - if (GetMode() == NETWORK_MODE_SERVER) - { - ProcessDisconnectedClients(); - } - else if (GetMode() == NETWORK_MODE_CLIENT) - { - ProcessPlayerInfo(); - } - ProcessPlayerList(); -} - -static bool ProcessPlayerAuthenticatePluginHooks( - const NetworkConnection& connection, std::string_view name, std::string_view publicKeyHash) -{ - #ifdef ENABLE_SCRIPTING - using namespace OpenRCT2::Scripting; - - auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); - if (hookEngine.HasSubscriptions(Scripting::HookType::networkAuthenticate)) - { - auto ctx = GetContext()->GetScriptEngine().GetContext(); - - // Create event args object - DukObject eObj(ctx); - eObj.Set("name", name); - eObj.Set("publicKeyHash", publicKeyHash); - eObj.Set("ipAddress", connection.Socket->GetIpAddress()); - eObj.Set("cancel", false); - auto e = eObj.Take(); - - // Call the subscriptions - hookEngine.Call(Scripting::HookType::networkAuthenticate, e, false); - - // Check if any hook has cancelled the join - if (AsOrDefault(e["cancel"], false)) - { - return false; - } - } - #endif - return true; -} - -static void ProcessPlayerJoinedPluginHooks(uint8_t playerId) -{ - #ifdef ENABLE_SCRIPTING - using namespace OpenRCT2::Scripting; - - auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); - if (hookEngine.HasSubscriptions(Scripting::HookType::networkJoin)) - { - auto ctx = GetContext()->GetScriptEngine().GetContext(); - - // Create event args object - DukObject eObj(ctx); - eObj.Set("player", playerId); - auto e = eObj.Take(); - - // Call the subscriptions - hookEngine.Call(Scripting::HookType::networkJoin, e, false); - } - #endif -} - -static void ProcessPlayerLeftPluginHooks(uint8_t playerId) -{ - #ifdef ENABLE_SCRIPTING - using namespace OpenRCT2::Scripting; - - auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); - if (hookEngine.HasSubscriptions(Scripting::HookType::networkLeave)) - { - auto ctx = GetContext()->GetScriptEngine().GetContext(); - - // Create event args object - DukObject eObj(ctx); - eObj.Set("player", playerId); - auto e = eObj.Take(); - - // Call the subscriptions - hookEngine.Call(Scripting::HookType::networkLeave, e, false); - } - #endif -} - -void NetworkBase::ProcessPlayerList() -{ - if (GetMode() == NETWORK_MODE_SERVER) - { - // Avoid sending multiple times the player list, we mark the list invalidated on modifications - // and then send at the end of the tick the final player list. - if (_playerListInvalidated) - { - _playerListInvalidated = false; - ServerSendPlayerList(); - } - } - else - { - // As client we have to keep things in order so the update is tick bound. - // Commands/Actions reference players and so this list needs to be in sync with those. - auto itPending = _pendingPlayerLists.begin(); - while (itPending != _pendingPlayerLists.end()) - { - if (itPending->first > getGameState().currentTicks) - break; - - // List of active players found in the list. - std::vector activePlayerIds; - std::vector newPlayers; - std::vector removedPlayers; - - for (const auto& pendingPlayer : itPending->second.players) - { - activePlayerIds.push_back(pendingPlayer.Id); - - auto* player = GetPlayerByID(pendingPlayer.Id); - if (player == nullptr) - { - // Add new player. - player = AddPlayer("", ""); - if (player != nullptr) + case SocketStatus::Resolving: { - *player = pendingPlayer; - if (player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) + if (_lastConnectStatus != SocketStatus::Resolving) { - _serverConnection->Player = player; + _lastConnectStatus = SocketStatus::Resolving; + char str_resolving[256]; + FormatStringLegacy(str_resolving, 256, STR_MULTIPLAYER_RESOLVING, nullptr); + + auto intent = Intent(WindowClass::NetworkStatus); + intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_resolving }); + intent.PutExtra( + INTENT_EXTRA_CALLBACK, []() -> void { OpenRCT2::GetContext()->GetNetwork().Close(); }); + ContextOpenIntent(&intent); } - newPlayers.push_back(player->Id); + break; } + case SocketStatus::Connecting: + { + if (_lastConnectStatus != SocketStatus::Connecting) + { + _lastConnectStatus = SocketStatus::Connecting; + char str_connecting[256]; + FormatStringLegacy(str_connecting, 256, STR_MULTIPLAYER_CONNECTING, nullptr); + + auto intent = Intent(WindowClass::NetworkStatus); + intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_connecting }); + intent.PutExtra( + INTENT_EXTRA_CALLBACK, []() -> void { OpenRCT2::GetContext()->GetNetwork().Close(); }); + ContextOpenIntent(&intent); + + server_connect_time = Platform::GetTicks(); + } + break; + } + case SocketStatus::Connected: + { + status = NETWORK_STATUS_CONNECTED; + _serverConnection->ResetLastPacketTime(); + Client_Send_TOKEN(); + char str_authenticating[256]; + FormatStringLegacy(str_authenticating, 256, STR_MULTIPLAYER_AUTHENTICATING, nullptr); + + auto intent = Intent(WindowClass::NetworkStatus); + intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_authenticating }); + intent.PutExtra(INTENT_EXTRA_CALLBACK, []() -> void { ::GetContext()->GetNetwork().Close(); }); + ContextOpenIntent(&intent); + break; + } + default: + { + const char* error = _serverConnection->Socket->GetError(); + if (error != nullptr) + { + Console::Error::WriteLine(error); + } + + Close(); + ContextForceCloseWindowByClass(WindowClass::NetworkStatus); + ContextShowError(STR_UNABLE_TO_CONNECT_TO_SERVER, kStringIdNone, {}); + break; + } + } + break; + } + case NETWORK_STATUS_CONNECTED: + { + if (!ProcessConnection(*_serverConnection)) + { + // Do not show disconnect message window when password window closed/canceled + if (_serverConnection->AuthStatus == NetworkAuth::RequirePassword) + { + ContextForceCloseWindowByClass(WindowClass::NetworkStatus); + } + else + { + char str_disconnected[256]; + + if (_serverConnection->GetLastDisconnectReason()) + { + const char* disconnect_reason = _serverConnection->GetLastDisconnectReason(); + FormatStringLegacy( + str_disconnected, 256, STR_MULTIPLAYER_DISCONNECTED_WITH_REASON, &disconnect_reason); + } + else + { + FormatStringLegacy(str_disconnected, 256, STR_MULTIPLAYER_DISCONNECTED_NO_REASON, nullptr); + } + + auto intent = Intent(WindowClass::NetworkStatus); + intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_disconnected }); + ContextOpenIntent(&intent); + } + + auto* windowMgr = Ui::GetWindowManager(); + windowMgr->CloseByClass(WindowClass::Multiplayer); + Close(); } else { - // Update. - *player = pendingPlayer; + uint32_t ticks = Platform::GetTicks(); + if (ticks - _lastSentHeartbeat >= 3000) + { + Client_Send_HEARTBEAT(*_serverConnection); + _lastSentHeartbeat = ticks; + } } - } - // Remove any players that are not in newly received list - for (const auto& player : player_list) + break; + } + } + } + + auto NetworkBase::GetPlayerIteratorByID(uint8_t id) const + { + return std::find_if(player_list.begin(), player_list.end(), [id](std::unique_ptr const& player) { + return player->Id == id; + }); + } + + NetworkPlayer* NetworkBase::GetPlayerByID(uint8_t id) const + { + auto it = GetPlayerIteratorByID(id); + if (it != player_list.end()) + { + return it->get(); + } + return nullptr; + } + + auto NetworkBase::GetGroupIteratorByID(uint8_t id) const + { + return std::find_if( + group_list.begin(), group_list.end(), [id](std::unique_ptr const& group) { return group->Id == id; }); + } + + NetworkGroup* NetworkBase::GetGroupByID(uint8_t id) const + { + auto it = GetGroupIteratorByID(id); + if (it != group_list.end()) + { + return it->get(); + } + return nullptr; + } + + int32_t NetworkBase::GetTotalNumPlayers() const noexcept + { + return static_cast(player_list.size()); + } + + int32_t NetworkBase::GetNumVisiblePlayers() const noexcept + { + if (IsServerPlayerInvisible) + return static_cast(player_list.size() - 1); + return static_cast(player_list.size()); + } + + const char* NetworkBase::FormatChat(NetworkPlayer* fromPlayer, const char* text) + { + static std::string formatted; + formatted.clear(); + + if (fromPlayer != nullptr) + { + auto& network = OpenRCT2::GetContext()->GetNetwork(); + auto it = network.GetGroupByID(fromPlayer->Id); + std::string groupName = ""; + std::vector colours; + if (it != nullptr) { - if (std::find(activePlayerIds.begin(), activePlayerIds.end(), player->Id) == activePlayerIds.end()) + groupName = it->GetName(); + if (groupName[0] != '{') { - removedPlayers.push_back(player->Id); + colours.push_back("{WHITE}"); } } - // Run player removed hooks (must be before players removed from list) - for (auto playerId : removedPlayers) + for (size_t i = 0; i < groupName.size(); ++i) { - ProcessPlayerLeftPluginHooks(playerId); + if (groupName[i] == '{') + { + std::string colour = "{"; + ++i; + while (i < groupName.size() && groupName[i] != '}' && groupName[i] != '{') + { + colour += groupName[i]; + ++i; + } + colour += '}'; + if (groupName[i] == '}' && i < groupName.size()) + { + colours.push_back(colour); + } + } } - // Run player joined hooks (must be after players added to list) - for (auto playerId : newPlayers) + if (colours.size() == 0 || (colours.size() == 1 && colours[0] == "{WHITE}")) { - ProcessPlayerJoinedPluginHooks(playerId); + formatted += "{BABYBLUE}"; + formatted += fromPlayer->Name; + } + else + { + size_t j = 0; + size_t proportionalSize = fromPlayer->Name.size() / colours.size(); + for (size_t i = 0; i < colours.size(); ++i) + { + formatted += colours[i]; + size_t numCharacters = proportionalSize + j; + for (; j < numCharacters && j < fromPlayer->Name.size(); ++j) + { + formatted += fromPlayer->Name[j]; + } + } + while (j < fromPlayer->Name.size()) + { + formatted += fromPlayer->Name[j]; + j++; + } } - // Now actually remove removed players from player list - player_list.erase( - std::remove_if( - player_list.begin(), player_list.end(), - [&removedPlayers](const std::unique_ptr& player) { - return std::find(removedPlayers.begin(), removedPlayers.end(), player->Id) != removedPlayers.end(); - }), - player_list.end()); - - _pendingPlayerLists.erase(itPending); - itPending = _pendingPlayerLists.begin(); + formatted += ": "; } + formatted += "{WHITE}"; + formatted += text; + return formatted.c_str(); } -} -void NetworkBase::ProcessPlayerInfo() -{ - const auto currentTicks = getGameState().currentTicks; - - auto range = _pendingPlayerInfo.equal_range(currentTicks); - for (auto it = range.first; it != range.second; it++) + void NetworkBase::SendPacketToClients(const NetworkPacket& packet, bool front, bool gameCmd) const { - auto* player = GetPlayerByID(it->second.Id); - if (player != nullptr) + for (auto& client_connection : client_connection_list) { - const NetworkPlayer& networkedInfo = it->second; - player->Flags = networkedInfo.Flags; - player->Group = networkedInfo.Group; - player->LastAction = networkedInfo.LastAction; - player->LastActionCoord = networkedInfo.LastActionCoord; - player->MoneySpent = networkedInfo.MoneySpent; - player->CommandsRan = networkedInfo.CommandsRan; + if (gameCmd) + { + // If marked as game command we can not send the packet to connections that are not fully connected. + // Sending the packet would cause the client to store a command that is behind the tick where he starts, + // which would be essentially never executed. The clients do not require commands before the server has not sent + // the map data. + if (client_connection->Player == nullptr) + { + continue; + } + } + client_connection->QueuePacket(packet, front); } } - _pendingPlayerInfo.erase(currentTicks); -} -void NetworkBase::ProcessDisconnectedClients() -{ - for (auto it = client_connection_list.begin(); it != client_connection_list.end();) + bool NetworkBase::CheckSRAND(uint32_t tick, uint32_t srand0) { - auto& connection = *it; + // We have to wait for the map to be loaded first, ticks may match current loaded map. + if (!_clientMapLoaded) + return true; - if (!connection->ShouldDisconnect) + auto itTickData = _serverTickData.find(tick); + if (itTickData == std::end(_serverTickData)) + return true; + + const ServerTickData storedTick = itTickData->second; + _serverTickData.erase(itTickData); + + if (storedTick.srand0 != srand0) { - it++; - continue; + LOG_INFO("Srand0 mismatch, client = %08X, server = %08X", srand0, storedTick.srand0); + return false; } - // Make sure to send all remaining packets out before disconnecting. - connection->SendQueuedData(); - connection->Socket->Disconnect(); + if (!storedTick.spriteHash.empty()) + { + EntitiesChecksum checksum = getGameState().entities.GetAllEntitiesChecksum(); + std::string clientSpriteHash = checksum.ToString(); + if (clientSpriteHash != storedTick.spriteHash) + { + LOG_INFO( + "Sprite hash mismatch, client = %s, server = %s", clientSpriteHash.c_str(), storedTick.spriteHash.c_str()); + return false; + } + } - ServerClientDisconnected(connection); - RemovePlayer(connection); - - it = client_connection_list.erase(it); - } -} - -void NetworkBase::AddClient(std::unique_ptr&& socket) -{ - // Log connection info. - char addr[128]; - snprintf(addr, sizeof(addr), "Client joined from %s", socket->GetHostName()); - AppendServerLog(addr); - - // Store connection - auto connection = std::make_unique(); - connection->Socket = std::move(socket); - - client_connection_list.push_back(std::move(connection)); -} - -void NetworkBase::ServerClientDisconnected(std::unique_ptr& connection) -{ - NetworkPlayer* connection_player = connection->Player; - if (connection_player == nullptr) - return; - - char text[256]; - const char* has_disconnected_args[2] = { - connection_player->Name.c_str(), - connection->GetLastDisconnectReason(), - }; - if (has_disconnected_args[1] != nullptr) - { - FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_WITH_REASON, has_disconnected_args); - } - else - { - FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_NO_REASON, &(has_disconnected_args[0])); + return true; } - ChatAddHistory(text); - Peep* pickup_peep = NetworkGetPickupPeep(connection_player->Id); - if (pickup_peep != nullptr) + bool NetworkBase::IsDesynchronised() const noexcept { - GameActions::PeepPickupAction pickupAction{ GameActions::PeepPickupType::Cancel, - pickup_peep->Id, - { NetworkGetPickupPeepOldX(connection_player->Id), 0, 0 }, - NetworkGetCurrentPlayerId() }; - auto res = GameActions::Execute(&pickupAction); + return _serverState.state == NetworkServerStatus::Desynced; } - ServerSendEventPlayerDisconnected( - const_cast(connection_player->Name.c_str()), connection->GetLastDisconnectReason()); - // Log player disconnected event - AppendServerLog(text); - - ProcessPlayerLeftPluginHooks(connection_player->Id); -} - -void NetworkBase::RemovePlayer(std::unique_ptr& connection) -{ - NetworkPlayer* connection_player = connection->Player; - if (connection_player == nullptr) - return; - - player_list.erase( - std::remove_if( - player_list.begin(), player_list.end(), - [connection_player](std::unique_ptr& player) { return player.get() == connection_player; }), - player_list.end()); - - // Send new player list. - _playerListInvalidated = true; -} - -NetworkPlayer* NetworkBase::AddPlayer(const std::string& name, const std::string& keyhash) -{ - NetworkPlayer* addedplayer = nullptr; - int32_t newid = -1; - if (GetMode() == NETWORK_MODE_SERVER) + bool NetworkBase::CheckDesynchronizaton() { - // Find first unused player id + const auto currentTicks = getGameState().currentTicks; + + // Check synchronisation + if (GetMode() == NETWORK_MODE_CLIENT && _serverState.state != NetworkServerStatus::Desynced + && !CheckSRAND(currentTicks, ScenarioRandState().s0)) + { + _serverState.state = NetworkServerStatus::Desynced; + _serverState.desyncTick = currentTicks; + + char str_desync[256]; + FormatStringLegacy(str_desync, 256, STR_MULTIPLAYER_DESYNC, nullptr); + + auto intent = Intent(WindowClass::NetworkStatus); + intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_desync }); + ContextOpenIntent(&intent); + + if (!Config::Get().network.StayConnected) + { + Close(); + } + + return true; + } + + return false; + } + + void NetworkBase::RequestStateSnapshot() + { + LOG_INFO("Requesting game state for tick %u", _serverState.desyncTick); + + Client_Send_RequestGameState(_serverState.desyncTick); + } + + NetworkServerState NetworkBase::GetServerState() const noexcept + { + return _serverState; + } + + void NetworkBase::KickPlayer(int32_t playerId) + { + for (auto& client_connection : client_connection_list) + { + if (client_connection->Player->Id == playerId) + { + // Disconnect the client gracefully + client_connection->SetLastDisconnectReason(STR_MULTIPLAYER_KICKED); + char str_disconnect_msg[256]; + FormatStringLegacy(str_disconnect_msg, 256, STR_MULTIPLAYER_KICKED_REASON, nullptr); + ServerSendSetDisconnectMsg(*client_connection, str_disconnect_msg); + client_connection->Disconnect(); + break; + } + } + } + + void NetworkBase::SetPassword(u8string_view password) + { + _password = password; + } + + void NetworkBase::ServerClientDisconnected() + { + if (GetMode() == NETWORK_MODE_CLIENT) + { + _serverConnection->Disconnect(); + } + } + + std::string NetworkBase::GenerateAdvertiseKey() + { + // Generate a string of 16 random hex characters (64-integer key as a hex formatted string) + static char hexChars[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', + }; + char key[17]; + for (int32_t i = 0; i < 16; i++) + { + int32_t hexCharIndex = UtilRand() % std::size(hexChars); + key[i] = hexChars[hexCharIndex]; + } + key[std::size(key) - 1] = 0; + + return key; + } + + std::string NetworkBase::GetMasterServerUrl() + { + if (Config::Get().network.MasterServerUrl.empty()) + { + return kMasterServerURL; + } + + return Config::Get().network.MasterServerUrl; + } + + NetworkGroup* NetworkBase::AddGroup() + { + NetworkGroup* addedgroup = nullptr; + int32_t newid = -1; + // Find first unused group id for (int32_t id = 0; id < 255; id++) { if (std::find_if( - player_list.begin(), player_list.end(), - [&id](std::unique_ptr const& player) { return player->Id == id; }) - == player_list.end()) + group_list.begin(), group_list.end(), + [&id](std::unique_ptr const& group) { return group->Id == id; }) + == group_list.end()) { newid = id; break; } } + if (newid != -1) + { + auto group = std::make_unique(); + group->Id = newid; + group->SetName("Group #" + std::to_string(newid)); + addedgroup = group.get(); + group_list.push_back(std::move(group)); + } + return addedgroup; } - else + + void NetworkBase::RemoveGroup(uint8_t id) { - newid = 0; - } - if (newid != -1) - { - std::unique_ptr player; + auto group = GetGroupIteratorByID(id); + if (group != group_list.end()) + { + group_list.erase(group); + } + if (GetMode() == NETWORK_MODE_SERVER) { - // Load keys host may have added manually - _userManager.Load(); + _userManager.UnsetUsersOfGroup(id); + _userManager.Save(); + } + } - // Check if the key is registered - const NetworkUser* networkUser = _userManager.GetUserByHash(keyhash); + uint8_t NetworkBase::GetGroupIDByHash(const std::string& keyhash) + { + const NetworkUser* networkUser = _userManager.GetUserByHash(keyhash); - player = std::make_unique(); - player->Id = newid; - player->KeyHash = keyhash; - if (networkUser == nullptr) + uint8_t groupId = GetDefaultGroup(); + if (networkUser != nullptr && networkUser->GroupId.has_value()) + { + const uint8_t assignedGroup = *networkUser->GroupId; + if (GetGroupByID(assignedGroup) != nullptr) { - player->Group = GetDefaultGroup(); - if (!name.empty()) - { - player->SetName(MakePlayerNameUnique(String::trim(name))); - } + groupId = assignedGroup; } else { - player->Group = networkUser->GroupId.has_value() ? *networkUser->GroupId : GetDefaultGroup(); - player->SetName(networkUser->Name); + LOG_WARNING( + "User %s is assigned to non-existent group %u. Assigning to default group (%u)", keyhash.c_str(), + assignedGroup, groupId); } - - // Send new player list. - _playerListInvalidated = true; } - else - { - player = std::make_unique(); - player->Id = newid; - player->Group = GetDefaultGroup(); - player->SetName(String::trim(std::string(name))); - } - - addedplayer = player.get(); - player_list.push_back(std::move(player)); + return groupId; } - return addedplayer; -} -std::string NetworkBase::MakePlayerNameUnique(const std::string& name) -{ - // Note: Player names are case-insensitive - - std::string new_name = name.substr(0, 31); - int32_t counter = 1; - bool unique; - do + uint8_t NetworkBase::GetDefaultGroup() const noexcept { - unique = true; + return default_group; + } - // Check if there is already a player with this name in the server - for (const auto& player : player_list) + void NetworkBase::SetDefaultGroup(uint8_t id) + { + if (GetGroupByID(id) != nullptr) { - if (String::iequals(player->Name, new_name)) + default_group = id; + } + } + + void NetworkBase::SaveGroups() + { + if (GetMode() == NETWORK_MODE_SERVER) + { + auto& env = GetContext().GetPlatformEnvironment(); + auto path = Path::Combine(env.GetDirectoryPath(DirBase::user), u8"groups.json"); + + json_t jsonGroups = json_t::array(); + for (auto& group : group_list) { - unique = false; - break; + jsonGroups.push_back(group->ToJson()); } - } - - if (unique) - { - // Check if there is already a registered player with this name - if (_userManager.GetUserByName(new_name) != nullptr) + json_t jsonGroupsCfg = { + { "default_group", default_group }, + { "groups", jsonGroups }, + }; + try { - unique = false; + Json::WriteToFile(path, jsonGroupsCfg); } - } - - if (!unique) - { - // Increment name counter - counter++; - new_name = name.substr(0, 31) + " #" + std::to_string(counter); - } - } while (!unique); - return new_name; -} - -void NetworkBase::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet) -{ - auto keyPath = NetworkGetPrivateKeyPath(Config::Get().network.PlayerName); - if (!File::Exists(keyPath)) - { - LOG_ERROR("Key file (%s) was not found. Restart client to re-generate it.", keyPath.c_str()); - return; - } - - try - { - auto fs = FileStream(keyPath, FileMode::open); - if (!_key.LoadPrivate(&fs)) - { - throw std::runtime_error("Failed to load private key."); - } - } - catch (const std::exception&) - { - LOG_ERROR("Failed to load key %s", keyPath.c_str()); - connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); - connection.Disconnect(); - return; - } - - uint32_t challenge_size; - packet >> challenge_size; - const char* challenge = reinterpret_cast(packet.Read(challenge_size)); - - std::vector signature; - const std::string pubkey = _key.PublicKeyString(); - _challenge.resize(challenge_size); - std::memcpy(_challenge.data(), challenge, challenge_size); - bool ok = _key.Sign(_challenge.data(), _challenge.size(), signature); - if (!ok) - { - LOG_ERROR("Failed to sign server's challenge."); - connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); - connection.Disconnect(); - return; - } - // Don't keep private key in memory. There's no need and it may get leaked - // when process dump gets collected at some point in future. - _key.Unload(); - - Client_Send_AUTH(Config::Get().network.PlayerName, gCustomPassword, pubkey, signature); -} - -void NetworkBase::ServerHandleRequestGamestate(NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t tick; - packet >> tick; - - if (_serverState.gamestateSnapshotsEnabled == false) - { - // Ignore this if this is off. - return; - } - - IGameStateSnapshots* snapshots = GetContext().GetGameStateSnapshots(); - - const GameStateSnapshot_t* snapshot = snapshots->GetLinkedSnapshot(tick); - if (snapshot != nullptr) - { - MemoryStream snapshotMemory; - DataSerialiser ds(true, snapshotMemory); - - snapshots->SerialiseSnapshot(const_cast(*snapshot), ds); - - uint32_t bytesSent = 0; - uint32_t length = static_cast(snapshotMemory.GetLength()); - while (bytesSent < length) - { - uint32_t dataSize = kChunkSize; - if (bytesSent + dataSize > snapshotMemory.GetLength()) + catch (const std::exception& ex) { - dataSize = snapshotMemory.GetLength() - bytesSent; - } - - NetworkPacket packetGameStateChunk(NetworkCommand::GameState); - packetGameStateChunk << tick << length << bytesSent << dataSize; - packetGameStateChunk.Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); - - connection.QueuePacket(std::move(packetGameStateChunk)); - - bytesSent += dataSize; - } - } -} - -void NetworkBase::ServerHandleHeartbeat(NetworkConnection& connection, NetworkPacket& packet) -{ - LOG_VERBOSE("Client %s heartbeat", connection.Socket->GetHostName()); - connection.ResetLastPacketTime(); -} - -void NetworkBase::Client_Handle_AUTH(NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t auth_status; - packet >> auth_status >> const_cast(player_id); - connection.AuthStatus = static_cast(auth_status); - switch (connection.AuthStatus) - { - case NetworkAuth::Ok: - Client_Send_GAMEINFO(); - break; - case NetworkAuth::BadName: - connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PLAYER_NAME); - connection.Disconnect(); - break; - case NetworkAuth::BadVersion: - { - auto version = std::string(packet.ReadString()); - auto versionp = version.c_str(); - connection.SetLastDisconnectReason(STR_MULTIPLAYER_INCORRECT_SOFTWARE_VERSION, &versionp); - connection.Disconnect(); - break; - } - case NetworkAuth::BadPassword: - connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PASSWORD); - connection.Disconnect(); - break; - case NetworkAuth::VerificationFailure: - connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); - connection.Disconnect(); - break; - case NetworkAuth::Full: - connection.SetLastDisconnectReason(STR_MULTIPLAYER_SERVER_FULL); - connection.Disconnect(); - break; - case NetworkAuth::RequirePassword: - ContextOpenWindowView(WV_NETWORK_PASSWORD); - break; - case NetworkAuth::UnknownKeyDisallowed: - connection.SetLastDisconnectReason(STR_MULTIPLAYER_UNKNOWN_KEY_DISALLOWED); - connection.Disconnect(); - break; - default: - connection.SetLastDisconnectReason(STR_MULTIPLAYER_RECEIVED_INVALID_DATA); - connection.Disconnect(); - break; - } -} - -void NetworkBase::ServerClientJoined(std::string_view name, const std::string& keyhash, NetworkConnection& connection) -{ - auto player = AddPlayer(std::string(name), keyhash); - connection.Player = player; - if (player != nullptr) - { - char text[256]; - const char* player_name = static_cast(player->Name.c_str()); - FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_JOINED_THE_GAME, &player_name); - ChatAddHistory(text); - - auto& context = GetContext(); - auto& objManager = context.GetObjectManager(); - auto objects = objManager.GetPackableObjects(); - ServerSendObjectsList(connection, objects); - ServerSendScripts(connection); - - // Log player joining event - std::string playerNameHash = player->Name + " (" + keyhash + ")"; - player_name = static_cast(playerNameHash.c_str()); - FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_JOINED_THE_GAME, &player_name); - AppendServerLog(text); - - ProcessPlayerJoinedPluginHooks(player->Id); - } -} - -void NetworkBase::ServerHandleToken(NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) -{ - uint8_t token_size = 10 + (rand() & 0x7f); - connection.Challenge.resize(token_size); - for (int32_t i = 0; i < token_size; i++) - { - connection.Challenge[i] = static_cast(rand() & 0xff); - } - ServerSendToken(connection); -} - -static void OpenNetworkProgress(StringId captionStringId) -{ - auto captionString = GetContext()->GetLocalisationService().GetString(captionStringId); - auto intent = Intent(INTENT_ACTION_PROGRESS_OPEN); - intent.PutExtra(INTENT_EXTRA_MESSAGE, captionString); - intent.PutExtra(INTENT_EXTRA_CALLBACK, []() -> void { ::GetContext()->GetNetwork().Close(); }); - ContextOpenIntent(&intent); -} - -void NetworkBase::Client_Handle_OBJECTS_LIST(NetworkConnection& connection, NetworkPacket& packet) -{ - auto& repo = GetContext().GetObjectRepository(); - - uint32_t index = 0; - uint32_t totalObjects = 0; - packet >> index >> totalObjects; - - static constexpr uint32_t kObjectStartIndex = 0; - if (index == kObjectStartIndex) - { - _missingObjects.clear(); - } - - if (totalObjects > 0) - { - OpenNetworkProgress(STR_MULTIPLAYER_RECEIVING_OBJECTS_LIST); - GetContext().SetProgress(index + 1, totalObjects); - - uint8_t objectType{}; - packet >> objectType; - - if (objectType == 0) - { - // DAT - auto entry = reinterpret_cast(packet.Read(sizeof(RCTObjectEntry))); - if (entry != nullptr) - { - const auto* object = repo.FindObject(entry); - if (object == nullptr) - { - auto objectName = std::string(entry->GetName()); - LOG_VERBOSE("Requesting object %s with checksum %x from server", objectName.c_str(), entry->checksum); - _missingObjects.push_back(ObjectEntryDescriptor(*entry)); - } - else if (object->ObjectEntry.checksum != entry->checksum || object->ObjectEntry.flags != entry->flags) - { - auto objectName = std::string(entry->GetName()); - LOG_WARNING( - "Object %s has different checksum/flags (%x/%x) than server (%x/%x).", objectName.c_str(), - object->ObjectEntry.checksum, object->ObjectEntry.flags, entry->checksum, entry->flags); - } - } - } - else - { - // JSON - auto identifier = packet.ReadString(); - if (!identifier.empty()) - { - const auto* object = repo.FindObject(identifier); - if (object == nullptr) - { - auto objectName = std::string(identifier); - LOG_VERBOSE("Requesting object %s from server", objectName.c_str()); - _missingObjects.push_back(ObjectEntryDescriptor(objectName)); - } + LOG_ERROR("Unable to save %s: %s", path.c_str(), ex.what()); } } } - if (index + 1 >= totalObjects) + void NetworkBase::SetupDefaultGroups() { - LOG_VERBOSE("client received object list, it has %u entries", totalObjects); - Client_Send_MAPREQUEST(_missingObjects); - _missingObjects.clear(); - } -} + // Admin group + auto admin = std::make_unique(); + admin->SetName("Admin"); + admin->ActionsAllowed.fill(0xFF); + admin->Id = 0; + group_list.push_back(std::move(admin)); -void NetworkBase::Client_Handle_SCRIPTS_HEADER(NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t numScripts{}; - uint32_t dataSize{}; - packet >> numScripts >> dataSize; + // Spectator group + auto spectator = std::make_unique(); + spectator->SetName("Spectator"); + spectator->ToggleActionPermission(NetworkPermission::Chat); + spectator->Id = 1; + group_list.push_back(std::move(spectator)); - #ifdef ENABLE_SCRIPTING - _serverScriptsData.data.Clear(); - _serverScriptsData.pluginCount = numScripts; - _serverScriptsData.dataSize = dataSize; - #else - if (numScripts > 0) - { - connection.SetLastDisconnectReason("The client requires plugin support."); - Close(); - } - #endif -} + // User group + auto user = std::make_unique(); + user->SetName("User"); + user->ActionsAllowed.fill(0xFF); + user->ToggleActionPermission(NetworkPermission::KickPlayer); + user->ToggleActionPermission(NetworkPermission::ModifyGroups); + user->ToggleActionPermission(NetworkPermission::SetPlayerGroup); + user->ToggleActionPermission(NetworkPermission::Cheat); + user->ToggleActionPermission(NetworkPermission::PasswordlessLogin); + user->ToggleActionPermission(NetworkPermission::ModifyTile); + user->ToggleActionPermission(NetworkPermission::EditScenarioOptions); + user->Id = 2; + group_list.push_back(std::move(user)); -void NetworkBase::Client_Handle_SCRIPTS_DATA(NetworkConnection& connection, NetworkPacket& packet) -{ - #ifdef ENABLE_SCRIPTING - uint32_t dataSize{}; - packet >> dataSize; - Guard::Assert(dataSize > 0); - - const auto* data = packet.Read(dataSize); - Guard::Assert(data != nullptr); - - auto& scriptsData = _serverScriptsData.data; - scriptsData.Write(data, dataSize); - - if (scriptsData.GetLength() == _serverScriptsData.dataSize) - { - auto& scriptEngine = GetContext().GetScriptEngine(); - - scriptsData.SetPosition(0); - for (uint32_t i = 0; i < _serverScriptsData.pluginCount; ++i) - { - const auto codeSize = scriptsData.ReadValue(); - const auto scriptData = scriptsData.ReadArray(codeSize); - - auto code = std::string_view(reinterpret_cast(scriptData.get()), codeSize); - scriptEngine.AddNetworkPlugin(code); - } - Guard::Assert(scriptsData.GetPosition() == scriptsData.GetLength()); - - // Empty the current buffer. - _serverScriptsData = {}; - } - #else - connection.SetLastDisconnectReason("The client requires plugin support."); - Close(); - #endif -} - -void NetworkBase::Client_Handle_GAMESTATE(NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t tick; - uint32_t totalSize; - uint32_t offset; - uint32_t dataSize; - - packet >> tick >> totalSize >> offset >> dataSize; - - if (offset == 0) - { - // Reset - _serverGameState = MemoryStream(); + SetDefaultGroup(1); } - _serverGameState.SetPosition(offset); - - const uint8_t* data = packet.Read(dataSize); - _serverGameState.Write(data, dataSize); - - LOG_VERBOSE( - "Received Game State %.02f%%", - (static_cast(_serverGameState.GetLength()) / static_cast(totalSize)) * 100.0f); - - if (_serverGameState.GetLength() == totalSize) + void NetworkBase::LoadGroups() { - _serverGameState.SetPosition(0); - DataSerialiser ds(false, _serverGameState); + group_list.clear(); - IGameStateSnapshots* snapshots = GetContext().GetGameStateSnapshots(); + auto& env = GetContext().GetPlatformEnvironment(); + auto path = Path::Combine(env.GetDirectoryPath(DirBase::user), u8"groups.json"); - GameStateSnapshot_t& serverSnapshot = snapshots->CreateSnapshot(); - snapshots->SerialiseSnapshot(serverSnapshot, ds); - - const GameStateSnapshot_t* desyncSnapshot = snapshots->GetLinkedSnapshot(tick); - if (desyncSnapshot != nullptr) - { - GameStateCompareData cmpData = snapshots->Compare(serverSnapshot, *desyncSnapshot); - - std::string outputPath = GetContext().GetPlatformEnvironment().GetDirectoryPath(DirBase::user, DirId::desyncLogs); - - Path::CreateDirectory(outputPath); - - char uniqueFileName[128] = {}; - snprintf( - uniqueFileName, sizeof(uniqueFileName), "desync_%llu_%u.txt", - static_cast(Platform::GetDatetimeNowUTC()), tick); - - std::string outputFile = Path::Combine(outputPath, uniqueFileName); - - if (snapshots->LogCompareDataToFile(outputFile, cmpData)) - { - LOG_INFO("Wrote desync report to '%s'", outputFile.c_str()); - - auto ft = Formatter(); - ft.Add(uniqueFileName); - - char str_desync[1024]; - FormatStringLegacy(str_desync, sizeof(str_desync), STR_DESYNC_REPORT, ft.Data()); - - auto intent = Intent(WindowClass::NetworkStatus); - intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_desync }); - ContextOpenIntent(&intent); - } - } - } -} - -void NetworkBase::ServerHandleMapRequest(NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t size; - packet >> size; - LOG_VERBOSE("Client requested %u objects", size); - auto& repo = GetContext().GetObjectRepository(); - for (uint32_t i = 0; i < size; i++) - { - uint8_t generation{}; - packet >> generation; - - std::string objectName; - const ObjectRepositoryItem* item{}; - if (generation == static_cast(ObjectGeneration::DAT)) - { - const auto* entry = reinterpret_cast(packet.Read(sizeof(RCTObjectEntry))); - objectName = std::string(entry->GetName()); - LOG_VERBOSE("Client requested object %s", objectName.c_str()); - item = repo.FindObject(entry); - } - else - { - objectName = std::string(packet.ReadString()); - LOG_VERBOSE("Client requested object %s", objectName.c_str()); - item = repo.FindObject(objectName); - } - - if (item == nullptr) - { - LOG_WARNING("Client tried getting non-existent object %s from us.", objectName.c_str()); - } - else - { - connection.RequestedObjects.push_back(item); - } - } - - auto player_name = connection.Player->Name.c_str(); - ServerSendMap(&connection); - ServerSendEventPlayerJoined(player_name); - ServerSendGroupList(connection); -} - -void NetworkBase::ServerHandleAuth(NetworkConnection& connection, NetworkPacket& packet) -{ - if (connection.AuthStatus != NetworkAuth::Ok) - { - auto* hostName = connection.Socket->GetHostName(); - auto gameversion = packet.ReadString(); - auto name = packet.ReadString(); - auto password = packet.ReadString(); - auto pubkey = packet.ReadString(); - uint32_t sigsize; - packet >> sigsize; - if (pubkey.empty()) - { - connection.AuthStatus = NetworkAuth::VerificationFailure; - } - else + json_t jsonGroupConfig; + if (File::Exists(path)) { try { - // RSA technically supports keys up to 65536 bits, so this is the - // maximum signature size for now. - constexpr auto MaxRSASignatureSizeInBytes = 8192; - - if (sigsize == 0 || sigsize > MaxRSASignatureSizeInBytes) - { - throw std::runtime_error("Invalid signature size"); - } - - std::vector signature; - signature.resize(sigsize); - - const uint8_t* signatureData = packet.Read(sigsize); - if (signatureData == nullptr) - { - throw std::runtime_error("Failed to read packet."); - } - - std::memcpy(signature.data(), signatureData, sigsize); - - auto ms = MemoryStream(pubkey.data(), pubkey.size()); - if (!connection.Key.LoadPublic(&ms)) - { - throw std::runtime_error("Failed to load public key."); - } - - bool verified = connection.Key.Verify(connection.Challenge.data(), connection.Challenge.size(), signature); - const std::string hash = connection.Key.PublicKeyHash(); - if (verified) - { - LOG_VERBOSE("Connection %s: Signature verification ok. Hash %s", hostName, hash.c_str()); - if (Config::Get().network.KnownKeysOnly && _userManager.GetUserByHash(hash) == nullptr) - { - LOG_VERBOSE("Connection %s: Hash %s, not known", hostName, hash.c_str()); - connection.AuthStatus = NetworkAuth::UnknownKeyDisallowed; - } - else - { - connection.AuthStatus = NetworkAuth::Verified; - } - } - else - { - connection.AuthStatus = NetworkAuth::VerificationFailure; - LOG_VERBOSE("Connection %s: Signature verification failed!", hostName); - } + jsonGroupConfig = Json::ReadFromFile(path); } - catch (const std::exception&) + catch (const std::exception& e) { - connection.AuthStatus = NetworkAuth::VerificationFailure; - LOG_VERBOSE("Connection %s: Signature verification failed, invalid data!", hostName); + LOG_ERROR("Failed to read %s as JSON. Setting default groups. %s", path.c_str(), e.what()); } } - bool passwordless = false; - if (connection.AuthStatus == NetworkAuth::Verified) + if (!jsonGroupConfig.is_object()) { - const NetworkGroup* group = GetGroupByID(GetGroupIDByHash(connection.Key.PublicKeyHash())); - if (group != nullptr) - { - passwordless = group->CanPerformAction(NetworkPermission::PasswordlessLogin); - } - } - if (gameversion != NetworkGetVersion()) - { - connection.AuthStatus = NetworkAuth::BadVersion; - LOG_INFO("Connection %s: Bad version.", hostName); - } - else if (name.empty()) - { - connection.AuthStatus = NetworkAuth::BadName; - LOG_INFO("Connection %s: Bad name.", connection.Socket->GetHostName()); - } - else if (!passwordless) - { - if (password.empty() && !_password.empty()) - { - connection.AuthStatus = NetworkAuth::RequirePassword; - LOG_INFO("Connection %s: Requires password.", hostName); - } - else if (!password.empty() && _password != password) - { - connection.AuthStatus = NetworkAuth::BadPassword; - LOG_INFO("Connection %s: Bad password.", hostName); - } - } - - if (GetNumVisiblePlayers() >= Config::Get().network.Maxplayers) - { - connection.AuthStatus = NetworkAuth::Full; - LOG_INFO("Connection %s: Server is full.", hostName); - } - else if (connection.AuthStatus == NetworkAuth::Verified) - { - const std::string hash = connection.Key.PublicKeyHash(); - if (ProcessPlayerAuthenticatePluginHooks(connection, name, hash)) - { - connection.AuthStatus = NetworkAuth::Ok; - ServerClientJoined(name, hash, connection); - } - else - { - connection.AuthStatus = NetworkAuth::VerificationFailure; - LOG_INFO("Connection %s: Denied by plugin.", hostName); - } - } - - ServerSendAuth(connection); - } -} - -void NetworkBase::Client_Handle_MAP([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t size, offset; - packet >> size >> offset; - int32_t chunksize = static_cast(packet.Header.Size - packet.BytesRead); - if (chunksize <= 0) - { - return; - } - if (offset == 0) - { - // Start of a new map load, clear the queue now as we have to buffer them - // until the map is fully loaded. - GameActions::ClearQueue(); - GameActions::SuspendQueue(); - - _serverTickData.clear(); - _clientMapLoaded = false; - - OpenNetworkProgress(STR_MULTIPLAYER_DOWNLOADING_MAP); - } - if (size > chunk_buffer.size()) - { - chunk_buffer.resize(size); - } - - const auto currentProgressKiB = (offset + chunksize) / 1024; - const auto totalSizeKiB = size / 1024; - - GetContext().SetProgress(currentProgressKiB, totalSizeKiB, STR_STRING_M_OF_N_KIB); - - std::memcpy(&chunk_buffer[offset], const_cast(static_cast(packet.Read(chunksize))), chunksize); - if (offset + chunksize == size) - { - // Allow queue processing of game actions again. - GameActions::ResumeQueue(); - - ContextForceCloseWindowByClass(WindowClass::ProgressWindow); - GameUnloadScripts(); - GameNotifyMapChange(); - - bool has_to_free = false; - uint8_t* data = &chunk_buffer[0]; - size_t data_size = size; - auto ms = MemoryStream(data, data_size); - if (LoadMap(&ms)) - { - GameLoadInit(); - GameLoadScripts(); - GameNotifyMapChanged(); - _serverState.tick = getGameState().currentTicks; - // NetworkStatusOpen("Loaded new map from network"); - _serverState.state = NetworkServerStatus::Ok; - _clientMapLoaded = true; - gFirstTimeSaving = true; - - // Notify user he is now online and which shortcut key enables chat - NetworkChatShowConnectedMessage(); - - // Fix invalid vehicle sprite sizes, thus preventing visual corruption of sprites - FixInvalidVehicleSpriteSizes(); - - // NOTE: Game actions are normally processed before processing the player list. - // Given that during map load game actions are buffered we have to process the - // player list first to have valid players for the queued game actions. - ProcessPlayerList(); + SetupDefaultGroups(); } else { - // Something went wrong, game is not loaded. Return to main screen. - auto loadOrQuitAction = GameActions::LoadOrQuitAction( - GameActions::LoadOrQuitModes::OpenSavePrompt, PromptMode::saveBeforeQuit); - GameActions::Execute(&loadOrQuitAction); - } - if (has_to_free) - { - free(data); - } - } -} - -bool NetworkBase::LoadMap(IStream* stream) -{ - bool result = false; - try - { - auto& context = GetContext(); - auto& objManager = context.GetObjectManager(); - auto importer = ParkImporter::CreateParkFile(context.GetObjectRepository()); - auto loadResult = importer->LoadFromStream(stream, false); - objManager.LoadObjects(loadResult.RequiredObjects); - - MapAnimations::ClearAll(); - // TODO: Have a separate GameState and exchange once loaded. - auto& gameState = getGameState(); - importer->Import(gameState); - - EntityTweener::Get().Reset(); - MapAnimations::MarkAllTiles(); - - gLastAutoSaveUpdate = kAutosavePause; - result = true; - } - catch (const std::exception& e) - { - Console::Error::WriteLine("Unable to read map from server: %s", e.what()); - } - return result; -} - -bool NetworkBase::SaveMap(IStream* stream, const std::vector& objects) const -{ - bool result = false; - PrepareMapForSave(); - try - { - auto exporter = std::make_unique(); - exporter->ExportObjectsList = objects; - - auto& gameState = getGameState(); - exporter->Export(gameState, *stream, kParkFileNetCompressionLevel); - result = true; - } - catch (const std::exception& e) - { - Console::Error::WriteLine("Unable to serialise map: %s", e.what()); - } - return result; -} - -void NetworkBase::Client_Handle_CHAT([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - auto text = packet.ReadString(); - if (!text.empty()) - { - ChatAddHistory(std::string(text)); - } -} - -static bool ProcessChatMessagePluginHooks(uint8_t playerId, std::string& text) -{ - #ifdef ENABLE_SCRIPTING - auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); - if (hookEngine.HasSubscriptions(Scripting::HookType::networkChat)) - { - auto ctx = GetContext()->GetScriptEngine().GetContext(); - - // Create event args object - auto objIdx = duk_push_object(ctx); - duk_push_number(ctx, playerId); - duk_put_prop_string(ctx, objIdx, "player"); - duk_push_string(ctx, text.c_str()); - duk_put_prop_string(ctx, objIdx, "message"); - auto e = DukValue::take_from_stack(ctx); - - // Call the subscriptions - hookEngine.Call(Scripting::HookType::networkChat, e, false); - - // Update text from object if subscriptions changed it - if (e["message"].type() != DukValue::Type::STRING) - { - // Subscription set text to non-string, do not relay message - return false; - } - text = e["message"].as_string(); - if (text.empty()) - { - // Subscription set text to empty string, do not relay message - return false; - } - } - #endif - return true; -} - -void NetworkBase::ServerHandleChat(NetworkConnection& connection, NetworkPacket& packet) -{ - auto szText = packet.ReadString(); - if (szText.empty()) - return; - - if (connection.Player != nullptr) - { - NetworkGroup* group = GetGroupByID(connection.Player->Group); - if (group == nullptr || !group->CanPerformAction(NetworkPermission::Chat)) - { - return; - } - } - - std::string text(szText); - if (connection.Player != nullptr) - { - if (!ProcessChatMessagePluginHooks(connection.Player->Id, text)) - { - // Message not to be relayed - return; - } - } - - const char* formatted = FormatChat(connection.Player, text.c_str()); - ChatAddHistory(formatted); - ServerSendChat(formatted); -} - -void NetworkBase::Client_Handle_GAME_ACTION([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t tick; - GameCommand actionType; - packet >> tick >> actionType; - - MemoryStream stream; - const size_t size = packet.Header.Size - packet.BytesRead; - stream.WriteArray(packet.Read(size), size); - stream.SetPosition(0); - - DataSerialiser ds(false, stream); - - GameActions::GameAction::Ptr action = GameActions::Create(actionType); - if (action == nullptr) - { - LOG_ERROR("Received unregistered game action type: 0x%08X", actionType); - return; - } - action->Serialise(ds); - - if (player_id == action->GetPlayer().id) - { - // Only execute callbacks that belong to us, - // clients can have identical network ids assigned. - auto itr = _gameActionCallbacks.find(action->GetNetworkId()); - if (itr != _gameActionCallbacks.end()) - { - action->SetCallback(itr->second); - _gameActionCallbacks.erase(itr); - } - } - - GameActions::Enqueue(std::move(action), tick); -} - -void NetworkBase::ServerHandleGameAction(NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t tick; - GameCommand actionType; - - NetworkPlayer* player = connection.Player; - if (player == nullptr) - { - return; - } - - packet >> tick >> actionType; - - // Don't let clients send pause or quit - if (actionType == GameCommand::TogglePause || actionType == GameCommand::LoadOrQuit) - { - return; - } - - if (actionType != GameCommand::Custom) - { - // Check if player's group permission allows command to run - NetworkGroup* group = GetGroupByID(connection.Player->Group); - if (group == nullptr || group->CanPerformCommand(actionType) == false) - { - ServerSendShowError(connection, STR_CANT_DO_THIS, STR_PERMISSION_DENIED); - return; - } - } - - // Create and enqueue the action. - GameActions::GameAction::Ptr ga = GameActions::Create(actionType); - if (ga == nullptr) - { - LOG_ERROR( - "Received unregistered game action type: 0x%08X from player: (%d) %s", actionType, connection.Player->Id, - connection.Player->Name.c_str()); - return; - } - - // Player who is hosting is not affected by cooldowns. - if ((player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) == 0) - { - auto cooldownIt = player->CooldownTime.find(actionType); - if (cooldownIt != std::end(player->CooldownTime)) - { - if (cooldownIt->second > 0) + json_t jsonGroups = jsonGroupConfig["groups"]; + if (jsonGroups.is_array()) { - ServerSendShowError(connection, STR_CANT_DO_THIS, STR_NETWORK_ACTION_RATE_LIMIT_MESSAGE); - return; + for (auto& jsonGroup : jsonGroups) + { + group_list.emplace_back(std::make_unique(NetworkGroup::FromJson(jsonGroup))); + } + } + + default_group = Json::GetNumber(jsonGroupConfig["default_group"]); + if (GetGroupByID(default_group) == nullptr) + { + default_group = 0; } } - uint32_t cooldownTime = ga->GetCooldownTime(); - if (cooldownTime > 0) + // Host group should always contain all permissions. + group_list.at(0)->ActionsAllowed.fill(0xFF); + } + + std::string NetworkBase::BeginLog( + const std::string& directory, const std::string& midName, const std::string& filenameFormat) + { + utf8 filename[256]; + time_t timer; + time(&timer); + auto tmInfo = localtime(&timer); + if (strftime(filename, sizeof(filename), filenameFormat.c_str(), tmInfo) == 0) { - player->CooldownTime[actionType] = cooldownTime; + throw std::runtime_error("strftime failed"); } + + auto directoryMidName = Path::Combine(directory, midName); + Path::CreateDirectory(directoryMidName); + return Path::Combine(directoryMidName, filename); } - DataSerialiser stream(false); - const size_t size = packet.Header.Size - packet.BytesRead; - stream.GetStream().WriteArray(packet.Read(size), size); - stream.GetStream().SetPosition(0); - - ga->Serialise(stream); - // Set player to sender, should be 0 if sent from client. - ga->SetPlayer(NetworkPlayerId_t{ connection.Player->Id }); - - GameActions::Enqueue(std::move(ga), tick); -} - -void NetworkBase::Client_Handle_TICK([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t srand0; - uint32_t flags; - uint32_t serverTick; - - packet >> serverTick >> srand0 >> flags; - - ServerTickData tickData; - tickData.srand0 = srand0; - tickData.tick = serverTick; - - if (flags & NETWORK_TICK_FLAG_CHECKSUMS) + void NetworkBase::AppendLog(std::ostream& fs, std::string_view s) { - auto text = packet.ReadString(); - if (!text.empty()) + if (fs.fail()) { - tickData.spriteHash = text; + LOG_ERROR("bad ostream failed to append log"); + return; } - } - - // Don't let the history grow too much. - while (_serverTickData.size() >= 100) - { - _serverTickData.erase(_serverTickData.begin()); - } - - _serverState.tick = serverTick; - _serverTickData.emplace(serverTick, tickData); -} - -void NetworkBase::Client_Handle_PLAYERINFO([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t tick; - packet >> tick; - - NetworkPlayer playerInfo; - playerInfo.Read(packet); - - _pendingPlayerInfo.emplace(tick, playerInfo); -} - -void NetworkBase::Client_Handle_PLAYERLIST([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint32_t tick; - uint8_t size; - packet >> tick >> size; - - auto& pending = _pendingPlayerLists[tick]; - pending.players.clear(); - - for (uint32_t i = 0; i < size; i++) - { - NetworkPlayer tempplayer; - tempplayer.Read(packet); - - pending.players.push_back(std::move(tempplayer)); - } -} - -void NetworkBase::Client_Handle_PING([[maybe_unused]] NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) -{ - Client_Send_PING(); -} - -void NetworkBase::ServerHandlePing(NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) -{ - int32_t ping = Platform::GetTicks() - connection.PingTime; - if (ping < 0) - { - ping = 0; - } - if (connection.Player != nullptr) - { - connection.Player->Ping = ping; - auto* windowMgr = Ui::GetWindowManager(); - windowMgr->InvalidateByNumber(WindowClass::Player, connection.Player->Id); - } -} - -void NetworkBase::Client_Handle_PINGLIST([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint8_t size; - packet >> size; - for (uint32_t i = 0; i < size; i++) - { - uint8_t id; - uint16_t ping; - packet >> id >> ping; - NetworkPlayer* player = GetPlayerByID(id); - if (player != nullptr) + try { - player->Ping = ping; - } - } - - auto* windowMgr = Ui::GetWindowManager(); - windowMgr->InvalidateByClass(WindowClass::Player); -} - -void NetworkBase::Client_Handle_SETDISCONNECTMSG(NetworkConnection& connection, NetworkPacket& packet) -{ - auto disconnectmsg = packet.ReadString(); - if (!disconnectmsg.empty()) - { - connection.SetLastDisconnectReason(disconnectmsg); - } -} - -void NetworkBase::ServerHandleGameInfo(NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) -{ - ServerSendGameInfo(connection); -} - -void NetworkBase::Client_Handle_SHOWERROR([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - StringId title, message; - packet >> title >> message; - ContextShowError(title, message, {}); -} - -void NetworkBase::Client_Handle_GROUPLIST([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - group_list.clear(); - uint8_t size; - packet >> size >> default_group; - for (uint32_t i = 0; i < size; i++) - { - NetworkGroup group; - group.Read(packet); - auto newgroup = std::make_unique(group); - group_list.push_back(std::move(newgroup)); - } -} - -void NetworkBase::Client_Handle_EVENT([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - uint16_t eventType; - packet >> eventType; - switch (eventType) - { - case SERVER_EVENT_PLAYER_JOINED: - { - auto playerName = packet.ReadString(); - auto message = FormatStringID(STR_MULTIPLAYER_PLAYER_HAS_JOINED_THE_GAME, playerName); - ChatAddHistory(message); - break; - } - case SERVER_EVENT_PLAYER_DISCONNECTED: - { - auto playerName = packet.ReadString(); - auto reason = packet.ReadString(); - std::string message; - if (reason.empty()) + utf8 buffer[1024]; + time_t timer; + time(&timer); + auto tmInfo = localtime(&timer); + if (strftime(buffer, sizeof(buffer), "[%Y/%m/%d %H:%M:%S] ", tmInfo) != 0) { - message = FormatStringID(STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_NO_REASON, playerName); + String::append(buffer, sizeof(buffer), std::string(s).c_str()); + String::append(buffer, sizeof(buffer), PLATFORM_NEWLINE); + + fs.write(buffer, strlen(buffer)); + } + } + catch (const std::exception& ex) + { + LOG_ERROR("%s", ex.what()); + } + } + + void NetworkBase::BeginChatLog() + { + auto& env = GetContext().GetPlatformEnvironment(); + auto directory = env.GetDirectoryPath(DirBase::user, DirId::chatLogs); + _chatLogPath = BeginLog(directory, "", _chatLogFilenameFormat); + _chat_log_fs.open(fs::u8path(_chatLogPath), std::ios::out | std::ios::app); + } + + void NetworkBase::AppendChatLog(std::string_view s) + { + if (Config::Get().network.LogChat && _chat_log_fs.is_open()) + { + AppendLog(_chat_log_fs, s); + } + } + + void NetworkBase::CloseChatLog() + { + _chat_log_fs.close(); + } + + void NetworkBase::BeginServerLog() + { + auto& env = GetContext().GetPlatformEnvironment(); + auto directory = env.GetDirectoryPath(DirBase::user, DirId::serverLogs); + _serverLogPath = BeginLog(directory, ServerName, _serverLogFilenameFormat); + _server_log_fs.open(fs::u8path(_serverLogPath), std::ios::out | std::ios::app | std::ios::binary); + + // Log server start event + utf8 logMessage[256]; + if (GetMode() == NETWORK_MODE_CLIENT) + { + FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_CLIENT_STARTED, nullptr); + } + else if (GetMode() == NETWORK_MODE_SERVER) + { + FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_SERVER_STARTED, nullptr); + } + else + { + logMessage[0] = '\0'; + Guard::Assert(false, "Unknown network mode!"); + } + AppendServerLog(logMessage); + } + + void NetworkBase::AppendServerLog(const std::string& s) + { + if (Config::Get().network.LogServerActions && _server_log_fs.is_open()) + { + AppendLog(_server_log_fs, s); + } + } + + void NetworkBase::CloseServerLog() + { + // Log server stopped event + char logMessage[256]; + if (GetMode() == NETWORK_MODE_CLIENT) + { + FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_CLIENT_STOPPED, nullptr); + } + else if (GetMode() == NETWORK_MODE_SERVER) + { + FormatStringLegacy(logMessage, sizeof(logMessage), STR_LOG_SERVER_STOPPED, nullptr); + } + else + { + logMessage[0] = '\0'; + Guard::Assert(false, "Unknown network mode!"); + } + AppendServerLog(logMessage); + _server_log_fs.close(); + } + + void NetworkBase::Client_Send_RequestGameState(uint32_t tick) + { + if (_serverState.gamestateSnapshotsEnabled == false) + { + LOG_VERBOSE("Server does not store a gamestate history"); + return; + } + + LOG_VERBOSE("Requesting gamestate from server for tick %u", tick); + + NetworkPacket packet(NetworkCommand::RequestGameState); + packet << tick; + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::Client_Send_TOKEN() + { + LOG_VERBOSE("requesting token"); + NetworkPacket packet(NetworkCommand::Token); + _serverConnection->AuthStatus = NetworkAuth::Requested; + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::Client_Send_AUTH( + const std::string& name, const std::string& password, const std::string& pubkey, const std::vector& signature) + { + NetworkPacket packet(NetworkCommand::Auth); + packet.WriteString(NetworkGetVersion()); + packet.WriteString(name); + packet.WriteString(password); + packet.WriteString(pubkey); + assert(signature.size() <= static_cast(UINT32_MAX)); + packet << static_cast(signature.size()); + packet.Write(signature.data(), signature.size()); + _serverConnection->AuthStatus = NetworkAuth::Requested; + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::Client_Send_MAPREQUEST(const std::vector& objects) + { + LOG_VERBOSE("client requests %u objects", uint32_t(objects.size())); + NetworkPacket packet(NetworkCommand::MapRequest); + packet << static_cast(objects.size()); + for (const auto& object : objects) + { + std::string name(object.GetName()); + LOG_VERBOSE("client requests object %s", name.c_str()); + if (object.Generation == ObjectGeneration::DAT) + { + packet << static_cast(0); + packet.Write(&object.Entry, sizeof(RCTObjectEntry)); } else { - message = FormatStringID(STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_WITH_REASON, playerName, reason); + packet << static_cast(1); + packet.WriteString(name); } - ChatAddHistory(message); - break; } + _serverConnection->QueuePacket(std::move(packet)); } -} -void NetworkBase::Client_Send_GAMEINFO() -{ - LOG_VERBOSE("requesting gameinfo"); - NetworkPacket packet(NetworkCommand::GameInfo); - _serverConnection->QueuePacket(std::move(packet)); -} - -void NetworkBase::Client_Handle_GAMEINFO([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) -{ - auto jsonString = packet.ReadString(); - packet >> _serverState.gamestateSnapshotsEnabled; - packet >> IsServerPlayerInvisible; - - json_t jsonData = Json::FromString(jsonString); - - if (jsonData.is_object()) + void NetworkBase::ServerSendToken(NetworkConnection& connection) { - ServerName = Json::GetString(jsonData["name"]); - ServerDescription = Json::GetString(jsonData["description"]); - ServerGreeting = Json::GetString(jsonData["greeting"]); + NetworkPacket packet(NetworkCommand::Token); + packet << static_cast(connection.Challenge.size()); + packet.Write(connection.Challenge.data(), connection.Challenge.size()); + connection.QueuePacket(std::move(packet)); + } - json_t jsonProvider = jsonData["provider"]; - if (jsonProvider.is_object()) + void NetworkBase::ServerSendObjectsList( + NetworkConnection& connection, const std::vector& objects) const + { + LOG_VERBOSE("Server sends objects list with %u items", objects.size()); + + if (objects.empty()) { - ServerProviderName = Json::GetString(jsonProvider["name"]); - ServerProviderEmail = Json::GetString(jsonProvider["email"]); - ServerProviderWebsite = Json::GetString(jsonProvider["website"]); + NetworkPacket packet(NetworkCommand::ObjectsList); + packet << static_cast(0) << static_cast(objects.size()); + + connection.QueuePacket(std::move(packet)); } - } - - NetworkChatShowServerGreeting(); -} - -void NetworkReconnect() -{ - GetContext()->GetNetwork().Reconnect(); -} - -void NetworkShutdownClient() -{ - GetContext()->GetNetwork().ServerClientDisconnected(); -} - -int32_t NetworkBeginClient(const std::string& host, int32_t port) -{ - return GetContext()->GetNetwork().BeginClient(host, port); -} - -int32_t NetworkBeginServer(int32_t port, const std::string& address) -{ - return GetContext()->GetNetwork().BeginServer(port, address); -} - -void NetworkUpdate() -{ - GetContext()->GetNetwork().Update(); -} - -void NetworkProcessPending() -{ - GetContext()->GetNetwork().ProcessPending(); -} - -void NetworkFlush() -{ - GetContext()->GetNetwork().Flush(); -} - -int32_t NetworkGetMode() -{ - return GetContext()->GetNetwork().GetMode(); -} - -int32_t NetworkGetStatus() -{ - return GetContext()->GetNetwork().GetStatus(); -} - -bool NetworkIsDesynchronised() -{ - return GetContext()->GetNetwork().IsDesynchronised(); -} - -bool NetworkCheckDesynchronisation() -{ - return GetContext()->GetNetwork().CheckDesynchronizaton(); -} - -void NetworkRequestGamestateSnapshot() -{ - return GetContext()->GetNetwork().RequestStateSnapshot(); -} - -void NetworkSendTick() -{ - GetContext()->GetNetwork().ServerSendTick(); -} - -NetworkAuth NetworkGetAuthstatus() -{ - return GetContext()->GetNetwork().GetAuthStatus(); -} - -uint32_t NetworkGetServerTick() -{ - return GetContext()->GetNetwork().GetServerTick(); -} - -uint8_t NetworkGetCurrentPlayerId() -{ - return GetContext()->GetNetwork().GetPlayerID(); -} - -int32_t NetworkGetNumPlayers() -{ - return GetContext()->GetNetwork().GetTotalNumPlayers(); -} - -int32_t NetworkGetNumVisiblePlayers() -{ - return GetContext()->GetNetwork().GetNumVisiblePlayers(); -} - -const char* NetworkGetPlayerName(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - return static_cast(network.player_list[index]->Name.c_str()); -} - -uint32_t NetworkGetPlayerFlags(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - return network.player_list[index]->Flags; -} - -int32_t NetworkGetPlayerPing(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - return network.player_list[index]->Ping; -} - -int32_t NetworkGetPlayerID(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - return network.player_list[index]->Id; -} - -money64 NetworkGetPlayerMoneySpent(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - return network.player_list[index]->MoneySpent; -} - -std::string NetworkGetPlayerIPAddress(uint32_t id) -{ - auto& network = GetContext()->GetNetwork(); - auto conn = network.GetPlayerConnection(id); - if (conn != nullptr && conn->Socket != nullptr) - { - return conn->Socket->GetIpAddress(); - } - return {}; -} - -std::string NetworkGetPlayerPublicKeyHash(uint32_t id) -{ - auto& network = GetContext()->GetNetwork(); - auto player = network.GetPlayerByID(id); - if (player != nullptr) - { - return player->KeyHash; - } - return {}; -} - -void NetworkIncrementPlayerNumCommands(uint32_t playerIndex) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(playerIndex, network.player_list); - - network.player_list[playerIndex]->IncrementNumCommands(); -} - -void NetworkAddPlayerMoneySpent(uint32_t index, money64 cost) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - network.player_list[index]->AddMoneySpent(cost); -} - -int32_t NetworkGetPlayerLastAction(uint32_t index, int32_t time) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - if (time && Platform::GetTicks() > network.player_list[index]->LastActionTime + time) - { - return -999; - } - return network.player_list[index]->LastAction; -} - -void NetworkSetPlayerLastAction(uint32_t index, GameCommand command) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - network.player_list[index]->LastAction = static_cast(NetworkActions::FindCommand(command)); - network.player_list[index]->LastActionTime = Platform::GetTicks(); -} - -CoordsXYZ NetworkGetPlayerLastActionCoord(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, GetContext()->GetNetwork().player_list); - - return network.player_list[index]->LastActionCoord; -} - -void NetworkSetPlayerLastActionCoord(uint32_t index, const CoordsXYZ& coord) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - if (index < network.player_list.size()) - { - network.player_list[index]->LastActionCoord = coord; - } -} - -uint32_t NetworkGetPlayerCommandsRan(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, GetContext()->GetNetwork().player_list); - - return network.player_list[index]->CommandsRan; -} - -int32_t NetworkGetPlayerIndex(uint32_t id) -{ - auto& network = GetContext()->GetNetwork(); - auto it = network.GetPlayerIteratorByID(id); - if (it == network.player_list.end()) - { - return -1; - } - return static_cast(network.GetPlayerIteratorByID(id) - network.player_list.begin()); -} - -uint8_t NetworkGetPlayerGroup(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - - return network.player_list[index]->Group; -} - -void NetworkSetPlayerGroup(uint32_t index, uint32_t groupindex) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.player_list); - Guard::IndexInRange(groupindex, network.group_list); - - network.player_list[index]->Group = network.group_list[groupindex]->Id; -} - -int32_t NetworkGetGroupIndex(uint8_t id) -{ - auto& network = GetContext()->GetNetwork(); - auto it = network.GetGroupIteratorByID(id); - if (it == network.group_list.end()) - { - return -1; - } - return static_cast(network.GetGroupIteratorByID(id) - network.group_list.begin()); -} - -uint8_t NetworkGetGroupID(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(index, network.group_list); - - return network.group_list[index]->Id; -} - -int32_t NetworkGetNumGroups() -{ - auto& network = GetContext()->GetNetwork(); - return static_cast(network.group_list.size()); -} - -const char* NetworkGetGroupName(uint32_t index) -{ - auto& network = GetContext()->GetNetwork(); - return network.group_list[index]->GetName().c_str(); -} - -void NetworkChatShowConnectedMessage() -{ - auto windowManager = Ui::GetWindowManager(); - std::string s = windowManager->GetKeyboardShortcutString("interface.misc.multiplayer_chat"); - const char* sptr = s.c_str(); - - utf8 buffer[256]; - FormatStringLegacy(buffer, sizeof(buffer), STR_MULTIPLAYER_CONNECTED_CHAT_HINT, &sptr); - - NetworkPlayer server; - server.Name = "Server"; - const char* formatted = NetworkBase::FormatChat(&server, buffer); - ChatAddHistory(formatted); -} - -// Display server greeting if one exists -void NetworkChatShowServerGreeting() -{ - const auto& greeting = NetworkGetServerGreeting(); - if (!greeting.empty()) - { - thread_local std::string greeting_formatted; - greeting_formatted.assign("{OUTLINE}{GREEN}"); - greeting_formatted += greeting; - ChatAddHistory(greeting_formatted); - } -} - -GameActions::Result NetworkSetPlayerGroup( - NetworkPlayerId_t actionPlayerId, NetworkPlayerId_t playerId, uint8_t groupId, bool isExecuting) -{ - auto& network = GetContext()->GetNetwork(); - NetworkPlayer* player = network.GetPlayerByID(playerId); - - NetworkGroup* fromgroup = network.GetGroupByID(actionPlayerId); - if (player == nullptr) - { - return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_DO_THIS, kStringIdNone); - } - - if (network.GetGroupByID(groupId) == nullptr) - { - return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_DO_THIS, kStringIdNone); - } - - if (player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) - { - return GameActions::Result( - GameActions::Status::InvalidParameters, STR_CANT_CHANGE_GROUP_THAT_THE_HOST_BELONGS_TO, kStringIdNone); - } - - if (groupId == 0 && fromgroup != nullptr && fromgroup->Id != 0) - { - return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_SET_TO_THIS_GROUP, kStringIdNone); - } - - if (isExecuting) - { - player->Group = groupId; - - if (NetworkGetMode() == NETWORK_MODE_SERVER) + else { - // Add or update saved user - NetworkUserManager& userManager = network._userManager; - NetworkUser* networkUser = userManager.GetOrAddUser(player->KeyHash); - networkUser->GroupId = groupId; - networkUser->Name = player->Name; - userManager.Save(); + for (size_t i = 0; i < objects.size(); ++i) + { + const auto* object = objects[i]; + + NetworkPacket packet(NetworkCommand::ObjectsList); + packet << static_cast(i) << static_cast(objects.size()); + + if (object->Identifier.empty()) + { + // DAT + LOG_VERBOSE("Object %.8s (checksum %x)", object->ObjectEntry.name, object->ObjectEntry.checksum); + packet << static_cast(0); + packet.Write(&object->ObjectEntry, sizeof(RCTObjectEntry)); + } + else + { + // JSON + LOG_VERBOSE("Object %s", object->Identifier.c_str()); + packet << static_cast(1); + packet.WriteString(object->Identifier); + } + + connection.QueuePacket(std::move(packet)); + } + } + } + + void NetworkBase::ServerSendScripts(NetworkConnection& connection) + { + #ifdef ENABLE_SCRIPTING + using namespace OpenRCT2::Scripting; + + auto& scriptEngine = GetContext().GetScriptEngine(); + + // Get remote plugin list. + const auto remotePlugins = scriptEngine.GetRemotePlugins(); + LOG_VERBOSE("Server sends %zu scripts", remotePlugins.size()); + + // Build the data contents for each plugin. + MemoryStream pluginData; + for (auto& plugin : remotePlugins) + { + const auto& code = plugin->GetCode(); + + const auto codeSize = static_cast(code.size()); + pluginData.WriteValue(codeSize); + pluginData.WriteArray(code.c_str(), code.size()); } - auto* windowMgr = Ui::GetWindowManager(); - windowMgr->InvalidateByNumber(WindowClass::Player, playerId); + // Send the header packet. + NetworkPacket packetScriptHeader(NetworkCommand::ScriptsHeader); + packetScriptHeader << static_cast(remotePlugins.size()); + packetScriptHeader << static_cast(pluginData.GetLength()); + connection.QueuePacket(std::move(packetScriptHeader)); - // Log set player group event - NetworkPlayer* game_command_player = network.GetPlayerByID(actionPlayerId); - NetworkGroup* new_player_group = network.GetGroupByID(groupId); - char log_msg[256]; - const char* args[3] = { - player->Name.c_str(), - new_player_group->GetName().c_str(), - game_command_player->Name.c_str(), + // Segment the plugin data into chunks and send them. + const uint8_t* pluginDataBuffer = static_cast(pluginData.GetData()); + uint32_t dataOffset = 0; + while (dataOffset < pluginData.GetLength()) + { + const uint32_t chunkSize = std::min(pluginData.GetLength() - dataOffset, kChunkSize); + + NetworkPacket packet(NetworkCommand::ScriptsData); + packet << chunkSize; + packet.Write(pluginDataBuffer + dataOffset, chunkSize); + + connection.QueuePacket(std::move(packet)); + + dataOffset += chunkSize; + } + Guard::Assert(dataOffset == pluginData.GetLength()); + + #else + NetworkPacket packetScriptHeader(NetworkCommand::ScriptsHeader); + packetScriptHeader << static_cast(0u); + packetScriptHeader << static_cast(0u); + #endif + } + + void NetworkBase::Client_Send_HEARTBEAT(NetworkConnection& connection) const + { + LOG_VERBOSE("Sending heartbeat"); + + NetworkPacket packet(NetworkCommand::Heartbeat); + connection.QueuePacket(std::move(packet)); + } + + NetworkStats NetworkBase::GetStats() const + { + NetworkStats stats = {}; + if (mode == NETWORK_MODE_CLIENT) + { + stats = _serverConnection->Stats; + } + else + { + for (auto& connection : client_connection_list) + { + for (size_t n = 0; n < EnumValue(NetworkStatisticsGroup::Max); n++) + { + stats.bytesReceived[n] += connection->Stats.bytesReceived[n]; + stats.bytesSent[n] += connection->Stats.bytesSent[n]; + } + } + } + return stats; + } + + void NetworkBase::ServerSendAuth(NetworkConnection& connection) + { + uint8_t new_playerid = 0; + if (connection.Player != nullptr) + { + new_playerid = connection.Player->Id; + } + NetworkPacket packet(NetworkCommand::Auth); + packet << static_cast(connection.AuthStatus) << new_playerid; + if (connection.AuthStatus == NetworkAuth::BadVersion) + { + packet.WriteString(NetworkGetVersion()); + } + connection.QueuePacket(std::move(packet)); + if (connection.AuthStatus != NetworkAuth::Ok && connection.AuthStatus != NetworkAuth::RequirePassword) + { + connection.Disconnect(); + } + } + + void NetworkBase::ServerSendMap(NetworkConnection* connection) + { + std::vector objects; + if (connection != nullptr) + { + objects = connection->RequestedObjects; + } + else + { + // This will send all custom objects to connected clients + // TODO: fix it so custom objects negotiation is performed even in this case. + auto& context = GetContext(); + auto& objManager = context.GetObjectManager(); + objects = objManager.GetPackableObjects(); + } + + auto header = SaveForNetwork(objects); + if (header.empty()) + { + if (connection != nullptr) + { + connection->SetLastDisconnectReason(STR_MULTIPLAYER_CONNECTION_CLOSED); + connection->Disconnect(); + } + return; + } + size_t chunksize = kChunkSize; + for (size_t i = 0; i < header.size(); i += chunksize) + { + size_t datasize = std::min(chunksize, header.size() - i); + NetworkPacket packet(NetworkCommand::Map); + packet << static_cast(header.size()) << static_cast(i); + packet.Write(&header[i], datasize); + if (connection != nullptr) + { + connection->QueuePacket(std::move(packet)); + } + else + { + SendPacketToClients(packet); + } + } + } + + std::vector NetworkBase::SaveForNetwork(const std::vector& objects) const + { + std::vector result; + auto ms = MemoryStream(); + if (SaveMap(&ms, objects)) + { + result.resize(ms.GetLength()); + std::memcpy(result.data(), ms.GetData(), result.size()); + } + else + { + LOG_WARNING("Failed to export map."); + } + return result; + } + + void NetworkBase::Client_Send_CHAT(const char* text) + { + NetworkPacket packet(NetworkCommand::Chat); + packet.WriteString(text); + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::ServerSendChat(const char* text, const std::vector& playerIds) + { + NetworkPacket packet(NetworkCommand::Chat); + packet.WriteString(text); + + if (playerIds.empty()) + { + // Empty players / default value means send to all players + SendPacketToClients(packet); + } + else + { + for (auto playerId : playerIds) + { + auto conn = GetPlayerConnection(playerId); + if (conn != nullptr) + { + conn->QueuePacket(packet); + } + } + } + } + + void NetworkBase::Client_Send_GAME_ACTION(const GameActions::GameAction* action) + { + NetworkPacket packet(NetworkCommand::GameAction); + + uint32_t networkId = 0; + networkId = ++_actionId; + + // I know its ugly, want basic functionality for now. + const_cast(action)->SetNetworkId(networkId); + if (action->GetCallback()) + { + _gameActionCallbacks.insert(std::make_pair(networkId, action->GetCallback())); + } + + DataSerialiser stream(true); + action->Serialise(stream); + + packet << getGameState().currentTicks << action->GetType() << stream; + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::ServerSendGameAction(const GameActions::GameAction* action) + { + NetworkPacket packet(NetworkCommand::GameAction); + + DataSerialiser stream(true); + action->Serialise(stream); + + packet << getGameState().currentTicks << action->GetType() << stream; + + SendPacketToClients(packet); + } + + void NetworkBase::ServerSendTick() + { + NetworkPacket packet(NetworkCommand::Tick); + packet << getGameState().currentTicks << ScenarioRandState().s0; + uint32_t flags = 0; + // Simple counter which limits how often a sprite checksum gets sent. + // This can get somewhat expensive, so we don't want to push it every tick in release, + // but debug version can check more often. + static int32_t checksum_counter = 0; + checksum_counter++; + if (checksum_counter >= 100) + { + checksum_counter = 0; + flags |= NETWORK_TICK_FLAG_CHECKSUMS; + } + // Send flags always, so we can understand packet structure on the other end, + // and allow for some expansion. + packet << flags; + if (flags & NETWORK_TICK_FLAG_CHECKSUMS) + { + EntitiesChecksum checksum = getGameState().entities.GetAllEntitiesChecksum(); + packet.WriteString(checksum.ToString()); + } + + SendPacketToClients(packet); + } + + void NetworkBase::ServerSendPlayerInfo(int32_t playerId) + { + NetworkPacket packet(NetworkCommand::PlayerInfo); + packet << getGameState().currentTicks; + + auto* player = GetPlayerByID(playerId); + if (player == nullptr) + return; + + player->Write(packet); + SendPacketToClients(packet); + } + + void NetworkBase::ServerSendPlayerList() + { + NetworkPacket packet(NetworkCommand::PlayerList); + packet << getGameState().currentTicks << static_cast(player_list.size()); + for (auto& player : player_list) + { + player->Write(packet); + } + SendPacketToClients(packet); + } + + void NetworkBase::Client_Send_PING() + { + NetworkPacket packet(NetworkCommand::Ping); + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::ServerSendPing() + { + last_ping_sent_time = Platform::GetTicks(); + NetworkPacket packet(NetworkCommand::Ping); + for (auto& client_connection : client_connection_list) + { + client_connection->PingTime = Platform::GetTicks(); + } + SendPacketToClients(packet, true); + } + + void NetworkBase::ServerSendPingList() + { + NetworkPacket packet(NetworkCommand::PingList); + packet << static_cast(player_list.size()); + for (auto& player : player_list) + { + packet << player->Id << player->Ping; + } + SendPacketToClients(packet); + } + + void NetworkBase::ServerSendSetDisconnectMsg(NetworkConnection& connection, const char* msg) + { + NetworkPacket packet(NetworkCommand::DisconnectMessage); + packet.WriteString(msg); + connection.QueuePacket(std::move(packet)); + } + + json_t NetworkBase::GetServerInfoAsJson() const + { + json_t jsonObj = { + { "name", Config::Get().network.ServerName }, + { "requiresPassword", _password.size() > 0 }, + { "version", NetworkGetVersion() }, + { "players", GetNumVisiblePlayers() }, + { "maxPlayers", Config::Get().network.Maxplayers }, + { "description", Config::Get().network.ServerDescription }, + { "greeting", Config::Get().network.ServerGreeting }, + { "dedicated", gOpenRCT2Headless }, }; - FormatStringLegacy(log_msg, 256, STR_LOG_SET_PLAYER_GROUP, args); - NetworkAppendServerLog(log_msg); + return jsonObj; } - return GameActions::Result(); -} -GameActions::Result NetworkModifyGroups( - NetworkPlayerId_t actionPlayerId, GameActions::ModifyGroupType type, uint8_t groupId, const std::string& name, - uint32_t permissionIndex, GameActions::PermissionState permissionState, bool isExecuting) -{ - auto& network = GetContext()->GetNetwork(); - switch (type) + void NetworkBase::ServerSendGameInfo(NetworkConnection& connection) { - case GameActions::ModifyGroupType::AddGroup: + NetworkPacket packet(NetworkCommand::GameInfo); + #ifndef DISABLE_HTTP + json_t jsonObj = GetServerInfoAsJson(); + + // Provider details + json_t jsonProvider = { + { "name", Config::Get().network.ProviderName }, + { "email", Config::Get().network.ProviderEmail }, + { "website", Config::Get().network.ProviderWebsite }, + }; + + jsonObj["provider"] = jsonProvider; + + packet.WriteString(jsonObj.dump()); + packet << _serverState.gamestateSnapshotsEnabled; + packet << IsServerPlayerInvisible; + + #endif + connection.QueuePacket(std::move(packet)); + } + + void NetworkBase::ServerSendShowError(NetworkConnection& connection, StringId title, StringId message) + { + NetworkPacket packet(NetworkCommand::ShowError); + packet << title << message; + connection.QueuePacket(std::move(packet)); + } + + void NetworkBase::ServerSendGroupList(NetworkConnection& connection) + { + NetworkPacket packet(NetworkCommand::GroupList); + packet << static_cast(group_list.size()) << default_group; + for (auto& group : group_list) { - if (isExecuting) - { - NetworkGroup* newgroup = network.AddGroup(); - if (newgroup == nullptr) - { - return GameActions::Result(GameActions::Status::Unknown, STR_CANT_DO_THIS, kStringIdNone); - } - } + group->Write(packet); } - break; - case GameActions::ModifyGroupType::RemoveGroup: + connection.QueuePacket(std::move(packet)); + } + + void NetworkBase::ServerSendEventPlayerJoined(const char* playerName) + { + NetworkPacket packet(NetworkCommand::Event); + packet << static_cast(SERVER_EVENT_PLAYER_JOINED); + packet.WriteString(playerName); + SendPacketToClients(packet); + } + + void NetworkBase::ServerSendEventPlayerDisconnected(const char* playerName, const char* reason) + { + NetworkPacket packet(NetworkCommand::Event); + packet << static_cast(SERVER_EVENT_PLAYER_DISCONNECTED); + packet.WriteString(playerName); + packet.WriteString(reason); + SendPacketToClients(packet); + } + + bool NetworkBase::ProcessConnection(NetworkConnection& connection) + { + NetworkReadPacket packetStatus; + + uint32_t countProcessed = 0; + do { - if (groupId == 0) + countProcessed++; + packetStatus = connection.ReadPacket(); + switch (packetStatus) { - return GameActions::Result(GameActions::Status::Disallowed, STR_THIS_GROUP_CANNOT_BE_MODIFIED, kStringIdNone); - } - for (const auto& it : network.player_list) - { - if ((it.get())->Group == groupId) - { - return GameActions::Result( - GameActions::Status::Disallowed, STR_CANT_REMOVE_GROUP_THAT_PLAYERS_BELONG_TO, kStringIdNone); - } - } - if (isExecuting) - { - network.RemoveGroup(groupId); - } - } - break; - case GameActions::ModifyGroupType::SetPermissions: - { - if (groupId == 0) - { // can't change admin group permissions - return GameActions::Result(GameActions::Status::Disallowed, STR_THIS_GROUP_CANNOT_BE_MODIFIED, kStringIdNone); - } - NetworkGroup* mygroup = nullptr; - NetworkPlayer* player = network.GetPlayerByID(actionPlayerId); - auto networkPermission = static_cast(permissionIndex); - if (player != nullptr && permissionState == GameActions::PermissionState::Toggle) - { - mygroup = network.GetGroupByID(player->Group); - if (mygroup == nullptr || !mygroup->CanPerformAction(networkPermission)) - { - return GameActions::Result( - GameActions::Status::Disallowed, STR_CANT_MODIFY_PERMISSION_THAT_YOU_DO_NOT_HAVE_YOURSELF, - kStringIdNone); - } - } - if (isExecuting) - { - NetworkGroup* group = network.GetGroupByID(groupId); - if (group != nullptr) - { - if (permissionState != GameActions::PermissionState::Toggle) + case NetworkReadPacket::Disconnected: + // closed connection or network error + if (!connection.GetLastDisconnectReason()) { - if (mygroup != nullptr) + connection.SetLastDisconnectReason(STR_MULTIPLAYER_CONNECTION_CLOSED); + } + return false; + case NetworkReadPacket::Success: + // done reading in packet + ProcessPacket(connection, connection.InboundPacket); + if (!connection.IsValid()) + { + return false; + } + break; + case NetworkReadPacket::MoreData: + // more data required to be read + break; + case NetworkReadPacket::NoData: + // could not read anything from socket + break; + } + } while (packetStatus == NetworkReadPacket::Success && countProcessed < kMaxPacketsPerUpdate); + + if (!connection.ReceivedPacketRecently()) + { + if (!connection.GetLastDisconnectReason()) + { + connection.SetLastDisconnectReason(STR_MULTIPLAYER_NO_DATA); + } + return false; + } + + return true; + } + + void NetworkBase::ProcessPacket(NetworkConnection& connection, NetworkPacket& packet) + { + const auto& handlerList = GetMode() == NETWORK_MODE_SERVER ? server_command_handlers : client_command_handlers; + + auto it = handlerList.find(packet.GetCommand()); + if (it != handlerList.end()) + { + auto commandHandler = it->second; + if (connection.AuthStatus == NetworkAuth::Ok || !packet.CommandRequiresAuth()) + { + try + { + (this->*commandHandler)(connection, packet); + } + catch (const std::exception& ex) + { + LOG_VERBOSE("Exception during packet processing: %s", ex.what()); + } + } + } + + packet.Clear(); + } + + // This is called at the end of each game tick, this where things should be processed that affects the game state. + void NetworkBase::ProcessPending() + { + if (GetMode() == NETWORK_MODE_SERVER) + { + ProcessDisconnectedClients(); + } + else if (GetMode() == NETWORK_MODE_CLIENT) + { + ProcessPlayerInfo(); + } + ProcessPlayerList(); + } + + static bool ProcessPlayerAuthenticatePluginHooks( + const NetworkConnection& connection, std::string_view name, std::string_view publicKeyHash) + { + #ifdef ENABLE_SCRIPTING + using namespace OpenRCT2::Scripting; + + auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); + if (hookEngine.HasSubscriptions(Scripting::HookType::networkAuthenticate)) + { + auto ctx = GetContext()->GetScriptEngine().GetContext(); + + // Create event args object + DukObject eObj(ctx); + eObj.Set("name", name); + eObj.Set("publicKeyHash", publicKeyHash); + eObj.Set("ipAddress", connection.Socket->GetIpAddress()); + eObj.Set("cancel", false); + auto e = eObj.Take(); + + // Call the subscriptions + hookEngine.Call(Scripting::HookType::networkAuthenticate, e, false); + + // Check if any hook has cancelled the join + if (AsOrDefault(e["cancel"], false)) + { + return false; + } + } + #endif + return true; + } + + static void ProcessPlayerJoinedPluginHooks(uint8_t playerId) + { + #ifdef ENABLE_SCRIPTING + using namespace OpenRCT2::Scripting; + + auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); + if (hookEngine.HasSubscriptions(Scripting::HookType::networkJoin)) + { + auto ctx = GetContext()->GetScriptEngine().GetContext(); + + // Create event args object + DukObject eObj(ctx); + eObj.Set("player", playerId); + auto e = eObj.Take(); + + // Call the subscriptions + hookEngine.Call(Scripting::HookType::networkJoin, e, false); + } + #endif + } + + static void ProcessPlayerLeftPluginHooks(uint8_t playerId) + { + #ifdef ENABLE_SCRIPTING + using namespace OpenRCT2::Scripting; + + auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); + if (hookEngine.HasSubscriptions(Scripting::HookType::networkLeave)) + { + auto ctx = GetContext()->GetScriptEngine().GetContext(); + + // Create event args object + DukObject eObj(ctx); + eObj.Set("player", playerId); + auto e = eObj.Take(); + + // Call the subscriptions + hookEngine.Call(Scripting::HookType::networkLeave, e, false); + } + #endif + } + + void NetworkBase::ProcessPlayerList() + { + if (GetMode() == NETWORK_MODE_SERVER) + { + // Avoid sending multiple times the player list, we mark the list invalidated on modifications + // and then send at the end of the tick the final player list. + if (_playerListInvalidated) + { + _playerListInvalidated = false; + ServerSendPlayerList(); + } + } + else + { + // As client we have to keep things in order so the update is tick bound. + // Commands/Actions reference players and so this list needs to be in sync with those. + auto itPending = _pendingPlayerLists.begin(); + while (itPending != _pendingPlayerLists.end()) + { + if (itPending->first > getGameState().currentTicks) + break; + + // List of active players found in the list. + std::vector activePlayerIds; + std::vector newPlayers; + std::vector removedPlayers; + + for (const auto& pendingPlayer : itPending->second.players) + { + activePlayerIds.push_back(pendingPlayer.Id); + + auto* player = GetPlayerByID(pendingPlayer.Id); + if (player == nullptr) + { + // Add new player. + player = AddPlayer("", ""); + if (player != nullptr) { - if (permissionState == GameActions::PermissionState::SetAll) + *player = pendingPlayer; + if (player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) { - group->ActionsAllowed = mygroup->ActionsAllowed; - } - else - { - group->ActionsAllowed.fill(0x00); + _serverConnection->Player = player; } + newPlayers.push_back(player->Id); } } else { - group->ToggleActionPermission(networkPermission); + // Update. + *player = pendingPlayer; + } + } + + // Remove any players that are not in newly received list + for (const auto& player : player_list) + { + if (std::find(activePlayerIds.begin(), activePlayerIds.end(), player->Id) == activePlayerIds.end()) + { + removedPlayers.push_back(player->Id); + } + } + + // Run player removed hooks (must be before players removed from list) + for (auto playerId : removedPlayers) + { + ProcessPlayerLeftPluginHooks(playerId); + } + + // Run player joined hooks (must be after players added to list) + for (auto playerId : newPlayers) + { + ProcessPlayerJoinedPluginHooks(playerId); + } + + // Now actually remove removed players from player list + player_list.erase( + std::remove_if( + player_list.begin(), player_list.end(), + [&removedPlayers](const std::unique_ptr& player) { + return std::find(removedPlayers.begin(), removedPlayers.end(), player->Id) != removedPlayers.end(); + }), + player_list.end()); + + _pendingPlayerLists.erase(itPending); + itPending = _pendingPlayerLists.begin(); + } + } + } + + void NetworkBase::ProcessPlayerInfo() + { + const auto currentTicks = getGameState().currentTicks; + + auto range = _pendingPlayerInfo.equal_range(currentTicks); + for (auto it = range.first; it != range.second; it++) + { + auto* player = GetPlayerByID(it->second.Id); + if (player != nullptr) + { + const NetworkPlayer& networkedInfo = it->second; + player->Flags = networkedInfo.Flags; + player->Group = networkedInfo.Group; + player->LastAction = networkedInfo.LastAction; + player->LastActionCoord = networkedInfo.LastActionCoord; + player->MoneySpent = networkedInfo.MoneySpent; + player->CommandsRan = networkedInfo.CommandsRan; + } + } + _pendingPlayerInfo.erase(currentTicks); + } + + void NetworkBase::ProcessDisconnectedClients() + { + for (auto it = client_connection_list.begin(); it != client_connection_list.end();) + { + auto& connection = *it; + + if (!connection->ShouldDisconnect) + { + it++; + continue; + } + + // Make sure to send all remaining packets out before disconnecting. + connection->SendQueuedData(); + connection->Socket->Disconnect(); + + ServerClientDisconnected(connection); + RemovePlayer(connection); + + it = client_connection_list.erase(it); + } + } + + void NetworkBase::AddClient(std::unique_ptr&& socket) + { + // Log connection info. + char addr[128]; + snprintf(addr, sizeof(addr), "Client joined from %s", socket->GetHostName()); + AppendServerLog(addr); + + // Store connection + auto connection = std::make_unique(); + connection->Socket = std::move(socket); + + client_connection_list.push_back(std::move(connection)); + } + + void NetworkBase::ServerClientDisconnected(std::unique_ptr& connection) + { + NetworkPlayer* connection_player = connection->Player; + if (connection_player == nullptr) + return; + + char text[256]; + const char* has_disconnected_args[2] = { + connection_player->Name.c_str(), + connection->GetLastDisconnectReason(), + }; + if (has_disconnected_args[1] != nullptr) + { + FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_WITH_REASON, has_disconnected_args); + } + else + { + FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_NO_REASON, &(has_disconnected_args[0])); + } + + ChatAddHistory(text); + Peep* pickup_peep = NetworkGetPickupPeep(connection_player->Id); + if (pickup_peep != nullptr) + { + GameActions::PeepPickupAction pickupAction{ GameActions::PeepPickupType::Cancel, + pickup_peep->Id, + { NetworkGetPickupPeepOldX(connection_player->Id), 0, 0 }, + NetworkGetCurrentPlayerId() }; + auto res = GameActions::Execute(&pickupAction); + } + ServerSendEventPlayerDisconnected( + const_cast(connection_player->Name.c_str()), connection->GetLastDisconnectReason()); + + // Log player disconnected event + AppendServerLog(text); + + ProcessPlayerLeftPluginHooks(connection_player->Id); + } + + void NetworkBase::RemovePlayer(std::unique_ptr& connection) + { + NetworkPlayer* connection_player = connection->Player; + if (connection_player == nullptr) + return; + + player_list.erase( + std::remove_if( + player_list.begin(), player_list.end(), + [connection_player](std::unique_ptr& player) { return player.get() == connection_player; }), + player_list.end()); + + // Send new player list. + _playerListInvalidated = true; + } + + NetworkPlayer* NetworkBase::AddPlayer(const std::string& name, const std::string& keyhash) + { + NetworkPlayer* addedplayer = nullptr; + int32_t newid = -1; + if (GetMode() == NETWORK_MODE_SERVER) + { + // Find first unused player id + for (int32_t id = 0; id < 255; id++) + { + if (std::find_if( + player_list.begin(), player_list.end(), + [&id](std::unique_ptr const& player) { return player->Id == id; }) + == player_list.end()) + { + newid = id; + break; + } + } + } + else + { + newid = 0; + } + if (newid != -1) + { + std::unique_ptr player; + if (GetMode() == NETWORK_MODE_SERVER) + { + // Load keys host may have added manually + _userManager.Load(); + + // Check if the key is registered + const NetworkUser* networkUser = _userManager.GetUserByHash(keyhash); + + player = std::make_unique(); + player->Id = newid; + player->KeyHash = keyhash; + if (networkUser == nullptr) + { + player->Group = GetDefaultGroup(); + if (!name.empty()) + { + player->SetName(MakePlayerNameUnique(String::trim(name))); + } + } + else + { + player->Group = networkUser->GroupId.has_value() ? *networkUser->GroupId : GetDefaultGroup(); + player->SetName(networkUser->Name); + } + + // Send new player list. + _playerListInvalidated = true; + } + else + { + player = std::make_unique(); + player->Id = newid; + player->Group = GetDefaultGroup(); + player->SetName(String::trim(std::string(name))); + } + + addedplayer = player.get(); + player_list.push_back(std::move(player)); + } + return addedplayer; + } + + std::string NetworkBase::MakePlayerNameUnique(const std::string& name) + { + // Note: Player names are case-insensitive + + std::string new_name = name.substr(0, 31); + int32_t counter = 1; + bool unique; + do + { + unique = true; + + // Check if there is already a player with this name in the server + for (const auto& player : player_list) + { + if (String::iequals(player->Name, new_name)) + { + unique = false; + break; + } + } + + if (unique) + { + // Check if there is already a registered player with this name + if (_userManager.GetUserByName(new_name) != nullptr) + { + unique = false; + } + } + + if (!unique) + { + // Increment name counter + counter++; + new_name = name.substr(0, 31) + " #" + std::to_string(counter); + } + } while (!unique); + return new_name; + } + + void NetworkBase::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet) + { + auto keyPath = NetworkGetPrivateKeyPath(Config::Get().network.PlayerName); + if (!File::Exists(keyPath)) + { + LOG_ERROR("Key file (%s) was not found. Restart client to re-generate it.", keyPath.c_str()); + return; + } + + try + { + auto fs = FileStream(keyPath, FileMode::open); + if (!_key.LoadPrivate(&fs)) + { + throw std::runtime_error("Failed to load private key."); + } + } + catch (const std::exception&) + { + LOG_ERROR("Failed to load key %s", keyPath.c_str()); + connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); + connection.Disconnect(); + return; + } + + uint32_t challenge_size; + packet >> challenge_size; + const char* challenge = reinterpret_cast(packet.Read(challenge_size)); + + std::vector signature; + const std::string pubkey = _key.PublicKeyString(); + _challenge.resize(challenge_size); + std::memcpy(_challenge.data(), challenge, challenge_size); + bool ok = _key.Sign(_challenge.data(), _challenge.size(), signature); + if (!ok) + { + LOG_ERROR("Failed to sign server's challenge."); + connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); + connection.Disconnect(); + return; + } + // Don't keep private key in memory. There's no need and it may get leaked + // when process dump gets collected at some point in future. + _key.Unload(); + + Client_Send_AUTH(Config::Get().network.PlayerName, gCustomPassword, pubkey, signature); + } + + void NetworkBase::ServerHandleRequestGamestate(NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t tick; + packet >> tick; + + if (_serverState.gamestateSnapshotsEnabled == false) + { + // Ignore this if this is off. + return; + } + + IGameStateSnapshots* snapshots = GetContext().GetGameStateSnapshots(); + + const GameStateSnapshot_t* snapshot = snapshots->GetLinkedSnapshot(tick); + if (snapshot != nullptr) + { + MemoryStream snapshotMemory; + DataSerialiser ds(true, snapshotMemory); + + snapshots->SerialiseSnapshot(const_cast(*snapshot), ds); + + uint32_t bytesSent = 0; + uint32_t length = static_cast(snapshotMemory.GetLength()); + while (bytesSent < length) + { + uint32_t dataSize = kChunkSize; + if (bytesSent + dataSize > snapshotMemory.GetLength()) + { + dataSize = snapshotMemory.GetLength() - bytesSent; + } + + NetworkPacket packetGameStateChunk(NetworkCommand::GameState); + packetGameStateChunk << tick << length << bytesSent << dataSize; + packetGameStateChunk.Write(static_cast(snapshotMemory.GetData()) + bytesSent, dataSize); + + connection.QueuePacket(std::move(packetGameStateChunk)); + + bytesSent += dataSize; + } + } + } + + void NetworkBase::ServerHandleHeartbeat(NetworkConnection& connection, NetworkPacket& packet) + { + LOG_VERBOSE("Client %s heartbeat", connection.Socket->GetHostName()); + connection.ResetLastPacketTime(); + } + + void NetworkBase::Client_Handle_AUTH(NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t auth_status; + packet >> auth_status >> const_cast(player_id); + connection.AuthStatus = static_cast(auth_status); + switch (connection.AuthStatus) + { + case NetworkAuth::Ok: + Client_Send_GAMEINFO(); + break; + case NetworkAuth::BadName: + connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PLAYER_NAME); + connection.Disconnect(); + break; + case NetworkAuth::BadVersion: + { + auto version = std::string(packet.ReadString()); + auto versionp = version.c_str(); + connection.SetLastDisconnectReason(STR_MULTIPLAYER_INCORRECT_SOFTWARE_VERSION, &versionp); + connection.Disconnect(); + break; + } + case NetworkAuth::BadPassword: + connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PASSWORD); + connection.Disconnect(); + break; + case NetworkAuth::VerificationFailure: + connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE); + connection.Disconnect(); + break; + case NetworkAuth::Full: + connection.SetLastDisconnectReason(STR_MULTIPLAYER_SERVER_FULL); + connection.Disconnect(); + break; + case NetworkAuth::RequirePassword: + ContextOpenWindowView(WV_NETWORK_PASSWORD); + break; + case NetworkAuth::UnknownKeyDisallowed: + connection.SetLastDisconnectReason(STR_MULTIPLAYER_UNKNOWN_KEY_DISALLOWED); + connection.Disconnect(); + break; + default: + connection.SetLastDisconnectReason(STR_MULTIPLAYER_RECEIVED_INVALID_DATA); + connection.Disconnect(); + break; + } + } + + void NetworkBase::ServerClientJoined(std::string_view name, const std::string& keyhash, NetworkConnection& connection) + { + auto player = AddPlayer(std::string(name), keyhash); + connection.Player = player; + if (player != nullptr) + { + char text[256]; + const char* player_name = static_cast(player->Name.c_str()); + FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_JOINED_THE_GAME, &player_name); + ChatAddHistory(text); + + auto& context = GetContext(); + auto& objManager = context.GetObjectManager(); + auto objects = objManager.GetPackableObjects(); + ServerSendObjectsList(connection, objects); + ServerSendScripts(connection); + + // Log player joining event + std::string playerNameHash = player->Name + " (" + keyhash + ")"; + player_name = static_cast(playerNameHash.c_str()); + FormatStringLegacy(text, 256, STR_MULTIPLAYER_PLAYER_HAS_JOINED_THE_GAME, &player_name); + AppendServerLog(text); + + ProcessPlayerJoinedPluginHooks(player->Id); + } + } + + void NetworkBase::ServerHandleToken(NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) + { + uint8_t token_size = 10 + (rand() & 0x7f); + connection.Challenge.resize(token_size); + for (int32_t i = 0; i < token_size; i++) + { + connection.Challenge[i] = static_cast(rand() & 0xff); + } + ServerSendToken(connection); + } + + static void OpenNetworkProgress(StringId captionStringId) + { + auto captionString = GetContext()->GetLocalisationService().GetString(captionStringId); + auto intent = Intent(INTENT_ACTION_PROGRESS_OPEN); + intent.PutExtra(INTENT_EXTRA_MESSAGE, captionString); + intent.PutExtra(INTENT_EXTRA_CALLBACK, []() -> void { ::GetContext()->GetNetwork().Close(); }); + ContextOpenIntent(&intent); + } + + void NetworkBase::Client_Handle_OBJECTS_LIST(NetworkConnection& connection, NetworkPacket& packet) + { + auto& repo = GetContext().GetObjectRepository(); + + uint32_t index = 0; + uint32_t totalObjects = 0; + packet >> index >> totalObjects; + + static constexpr uint32_t kObjectStartIndex = 0; + if (index == kObjectStartIndex) + { + _missingObjects.clear(); + } + + if (totalObjects > 0) + { + OpenNetworkProgress(STR_MULTIPLAYER_RECEIVING_OBJECTS_LIST); + GetContext().SetProgress(index + 1, totalObjects); + + uint8_t objectType{}; + packet >> objectType; + + if (objectType == 0) + { + // DAT + auto entry = reinterpret_cast(packet.Read(sizeof(RCTObjectEntry))); + if (entry != nullptr) + { + const auto* object = repo.FindObject(entry); + if (object == nullptr) + { + auto objectName = std::string(entry->GetName()); + LOG_VERBOSE("Requesting object %s with checksum %x from server", objectName.c_str(), entry->checksum); + _missingObjects.push_back(ObjectEntryDescriptor(*entry)); + } + else if (object->ObjectEntry.checksum != entry->checksum || object->ObjectEntry.flags != entry->flags) + { + auto objectName = std::string(entry->GetName()); + LOG_WARNING( + "Object %s has different checksum/flags (%x/%x) than server (%x/%x).", objectName.c_str(), + object->ObjectEntry.checksum, object->ObjectEntry.flags, entry->checksum, entry->flags); + } + } + } + else + { + // JSON + auto identifier = packet.ReadString(); + if (!identifier.empty()) + { + const auto* object = repo.FindObject(identifier); + if (object == nullptr) + { + auto objectName = std::string(identifier); + LOG_VERBOSE("Requesting object %s from server", objectName.c_str()); + _missingObjects.push_back(ObjectEntryDescriptor(objectName)); } } } } - break; - case GameActions::ModifyGroupType::SetName: + + if (index + 1 >= totalObjects) { - NetworkGroup* group = network.GetGroupByID(groupId); - if (group == nullptr) + LOG_VERBOSE("client received object list, it has %u entries", totalObjects); + Client_Send_MAPREQUEST(_missingObjects); + _missingObjects.clear(); + } + } + + void NetworkBase::Client_Handle_SCRIPTS_HEADER(NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t numScripts{}; + uint32_t dataSize{}; + packet >> numScripts >> dataSize; + + #ifdef ENABLE_SCRIPTING + _serverScriptsData.data.Clear(); + _serverScriptsData.pluginCount = numScripts; + _serverScriptsData.dataSize = dataSize; + #else + if (numScripts > 0) + { + connection.SetLastDisconnectReason("The client requires plugin support."); + Close(); + } + #endif + } + + void NetworkBase::Client_Handle_SCRIPTS_DATA(NetworkConnection& connection, NetworkPacket& packet) + { + #ifdef ENABLE_SCRIPTING + uint32_t dataSize{}; + packet >> dataSize; + Guard::Assert(dataSize > 0); + + const auto* data = packet.Read(dataSize); + Guard::Assert(data != nullptr); + + auto& scriptsData = _serverScriptsData.data; + scriptsData.Write(data, dataSize); + + if (scriptsData.GetLength() == _serverScriptsData.dataSize) + { + auto& scriptEngine = GetContext().GetScriptEngine(); + + scriptsData.SetPosition(0); + for (uint32_t i = 0; i < _serverScriptsData.pluginCount; ++i) { - return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_RENAME_GROUP, kStringIdNone); + const auto codeSize = scriptsData.ReadValue(); + const auto scriptData = scriptsData.ReadArray(codeSize); + + auto code = std::string_view(reinterpret_cast(scriptData.get()), codeSize); + scriptEngine.AddNetworkPlugin(code); } + Guard::Assert(scriptsData.GetPosition() == scriptsData.GetLength()); - const char* oldName = group->GetName().c_str(); + // Empty the current buffer. + _serverScriptsData = {}; + } + #else + connection.SetLastDisconnectReason("The client requires plugin support."); + Close(); + #endif + } - if (strcmp(oldName, name.c_str()) == 0) + void NetworkBase::Client_Handle_GAMESTATE(NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t tick; + uint32_t totalSize; + uint32_t offset; + uint32_t dataSize; + + packet >> tick >> totalSize >> offset >> dataSize; + + if (offset == 0) + { + // Reset + _serverGameState = MemoryStream(); + } + + _serverGameState.SetPosition(offset); + + const uint8_t* data = packet.Read(dataSize); + _serverGameState.Write(data, dataSize); + + LOG_VERBOSE( + "Received Game State %.02f%%", + (static_cast(_serverGameState.GetLength()) / static_cast(totalSize)) * 100.0f); + + if (_serverGameState.GetLength() == totalSize) + { + _serverGameState.SetPosition(0); + DataSerialiser ds(false, _serverGameState); + + IGameStateSnapshots* snapshots = GetContext().GetGameStateSnapshots(); + + GameStateSnapshot_t& serverSnapshot = snapshots->CreateSnapshot(); + snapshots->SerialiseSnapshot(serverSnapshot, ds); + + const GameStateSnapshot_t* desyncSnapshot = snapshots->GetLinkedSnapshot(tick); + if (desyncSnapshot != nullptr) { - return GameActions::Result(); - } + GameStateCompareData cmpData = snapshots->Compare(serverSnapshot, *desyncSnapshot); - if (name.empty()) - { - return GameActions::Result( - GameActions::Status::InvalidParameters, STR_CANT_RENAME_GROUP, STR_INVALID_GROUP_NAME); - } + std::string outputPath = GetContext().GetPlatformEnvironment().GetDirectoryPath( + DirBase::user, DirId::desyncLogs); - if (isExecuting) - { - if (group != nullptr) + Path::CreateDirectory(outputPath); + + char uniqueFileName[128] = {}; + snprintf( + uniqueFileName, sizeof(uniqueFileName), "desync_%llu_%u.txt", + static_cast(Platform::GetDatetimeNowUTC()), tick); + + std::string outputFile = Path::Combine(outputPath, uniqueFileName); + + if (snapshots->LogCompareDataToFile(outputFile, cmpData)) { - group->SetName(name); + LOG_INFO("Wrote desync report to '%s'", outputFile.c_str()); + + auto ft = Formatter(); + ft.Add(uniqueFileName); + + char str_desync[1024]; + FormatStringLegacy(str_desync, sizeof(str_desync), STR_DESYNC_REPORT, ft.Data()); + + auto intent = Intent(WindowClass::NetworkStatus); + intent.PutExtra(INTENT_EXTRA_MESSAGE, std::string{ str_desync }); + ContextOpenIntent(&intent); } } } - break; - case GameActions::ModifyGroupType::SetDefault: + } + + void NetworkBase::ServerHandleMapRequest(NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t size; + packet >> size; + LOG_VERBOSE("Client requested %u objects", size); + auto& repo = GetContext().GetObjectRepository(); + for (uint32_t i = 0; i < size; i++) { - if (groupId == 0) + uint8_t generation{}; + packet >> generation; + + std::string objectName; + const ObjectRepositoryItem* item{}; + if (generation == static_cast(ObjectGeneration::DAT)) { - return GameActions::Result(GameActions::Status::Disallowed, STR_CANT_SET_TO_THIS_GROUP, kStringIdNone); + const auto* entry = reinterpret_cast(packet.Read(sizeof(RCTObjectEntry))); + objectName = std::string(entry->GetName()); + LOG_VERBOSE("Client requested object %s", objectName.c_str()); + item = repo.FindObject(entry); } - if (isExecuting) + else { - network.SetDefaultGroup(groupId); + objectName = std::string(packet.ReadString()); + LOG_VERBOSE("Client requested object %s", objectName.c_str()); + item = repo.FindObject(objectName); + } + + if (item == nullptr) + { + LOG_WARNING("Client tried getting non-existent object %s from us.", objectName.c_str()); + } + else + { + connection.RequestedObjects.push_back(item); } } - break; - default: - LOG_ERROR("Invalid Modify Group Type: %u", static_cast(type)); + + auto player_name = connection.Player->Name.c_str(); + ServerSendMap(&connection); + ServerSendEventPlayerJoined(player_name); + ServerSendGroupList(connection); + } + + void NetworkBase::ServerHandleAuth(NetworkConnection& connection, NetworkPacket& packet) + { + if (connection.AuthStatus != NetworkAuth::Ok) + { + auto* hostName = connection.Socket->GetHostName(); + auto gameversion = packet.ReadString(); + auto name = packet.ReadString(); + auto password = packet.ReadString(); + auto pubkey = packet.ReadString(); + uint32_t sigsize; + packet >> sigsize; + if (pubkey.empty()) + { + connection.AuthStatus = NetworkAuth::VerificationFailure; + } + else + { + try + { + // RSA technically supports keys up to 65536 bits, so this is the + // maximum signature size for now. + constexpr auto MaxRSASignatureSizeInBytes = 8192; + + if (sigsize == 0 || sigsize > MaxRSASignatureSizeInBytes) + { + throw std::runtime_error("Invalid signature size"); + } + + std::vector signature; + signature.resize(sigsize); + + const uint8_t* signatureData = packet.Read(sigsize); + if (signatureData == nullptr) + { + throw std::runtime_error("Failed to read packet."); + } + + std::memcpy(signature.data(), signatureData, sigsize); + + auto ms = MemoryStream(pubkey.data(), pubkey.size()); + if (!connection.Key.LoadPublic(&ms)) + { + throw std::runtime_error("Failed to load public key."); + } + + bool verified = connection.Key.Verify(connection.Challenge.data(), connection.Challenge.size(), signature); + const std::string hash = connection.Key.PublicKeyHash(); + if (verified) + { + LOG_VERBOSE("Connection %s: Signature verification ok. Hash %s", hostName, hash.c_str()); + if (Config::Get().network.KnownKeysOnly && _userManager.GetUserByHash(hash) == nullptr) + { + LOG_VERBOSE("Connection %s: Hash %s, not known", hostName, hash.c_str()); + connection.AuthStatus = NetworkAuth::UnknownKeyDisallowed; + } + else + { + connection.AuthStatus = NetworkAuth::Verified; + } + } + else + { + connection.AuthStatus = NetworkAuth::VerificationFailure; + LOG_VERBOSE("Connection %s: Signature verification failed!", hostName); + } + } + catch (const std::exception&) + { + connection.AuthStatus = NetworkAuth::VerificationFailure; + LOG_VERBOSE("Connection %s: Signature verification failed, invalid data!", hostName); + } + } + + bool passwordless = false; + if (connection.AuthStatus == NetworkAuth::Verified) + { + const NetworkGroup* group = GetGroupByID(GetGroupIDByHash(connection.Key.PublicKeyHash())); + if (group != nullptr) + { + passwordless = group->CanPerformAction(NetworkPermission::PasswordlessLogin); + } + } + if (gameversion != NetworkGetVersion()) + { + connection.AuthStatus = NetworkAuth::BadVersion; + LOG_INFO("Connection %s: Bad version.", hostName); + } + else if (name.empty()) + { + connection.AuthStatus = NetworkAuth::BadName; + LOG_INFO("Connection %s: Bad name.", connection.Socket->GetHostName()); + } + else if (!passwordless) + { + if (password.empty() && !_password.empty()) + { + connection.AuthStatus = NetworkAuth::RequirePassword; + LOG_INFO("Connection %s: Requires password.", hostName); + } + else if (!password.empty() && _password != password) + { + connection.AuthStatus = NetworkAuth::BadPassword; + LOG_INFO("Connection %s: Bad password.", hostName); + } + } + + if (GetNumVisiblePlayers() >= Config::Get().network.Maxplayers) + { + connection.AuthStatus = NetworkAuth::Full; + LOG_INFO("Connection %s: Server is full.", hostName); + } + else if (connection.AuthStatus == NetworkAuth::Verified) + { + const std::string hash = connection.Key.PublicKeyHash(); + if (ProcessPlayerAuthenticatePluginHooks(connection, name, hash)) + { + connection.AuthStatus = NetworkAuth::Ok; + ServerClientJoined(name, hash, connection); + } + else + { + connection.AuthStatus = NetworkAuth::VerificationFailure; + LOG_INFO("Connection %s: Denied by plugin.", hostName); + } + } + + ServerSendAuth(connection); + } + } + + void NetworkBase::Client_Handle_MAP([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t size, offset; + packet >> size >> offset; + int32_t chunksize = static_cast(packet.Header.Size - packet.BytesRead); + if (chunksize <= 0) + { + return; + } + if (offset == 0) + { + // Start of a new map load, clear the queue now as we have to buffer them + // until the map is fully loaded. + GameActions::ClearQueue(); + GameActions::SuspendQueue(); + + _serverTickData.clear(); + _clientMapLoaded = false; + + OpenNetworkProgress(STR_MULTIPLAYER_DOWNLOADING_MAP); + } + if (size > chunk_buffer.size()) + { + chunk_buffer.resize(size); + } + + const auto currentProgressKiB = (offset + chunksize) / 1024; + const auto totalSizeKiB = size / 1024; + + GetContext().SetProgress(currentProgressKiB, totalSizeKiB, STR_STRING_M_OF_N_KIB); + + std::memcpy(&chunk_buffer[offset], const_cast(static_cast(packet.Read(chunksize))), chunksize); + if (offset + chunksize == size) + { + // Allow queue processing of game actions again. + GameActions::ResumeQueue(); + + ContextForceCloseWindowByClass(WindowClass::ProgressWindow); + GameUnloadScripts(); + GameNotifyMapChange(); + + bool has_to_free = false; + uint8_t* data = &chunk_buffer[0]; + size_t data_size = size; + auto ms = MemoryStream(data, data_size); + if (LoadMap(&ms)) + { + GameLoadInit(); + GameLoadScripts(); + GameNotifyMapChanged(); + _serverState.tick = getGameState().currentTicks; + // NetworkStatusOpen("Loaded new map from network"); + _serverState.state = NetworkServerStatus::Ok; + _clientMapLoaded = true; + gFirstTimeSaving = true; + + // Notify user he is now online and which shortcut key enables chat + NetworkChatShowConnectedMessage(); + + // Fix invalid vehicle sprite sizes, thus preventing visual corruption of sprites + FixInvalidVehicleSpriteSizes(); + + // NOTE: Game actions are normally processed before processing the player list. + // Given that during map load game actions are buffered we have to process the + // player list first to have valid players for the queued game actions. + ProcessPlayerList(); + } + else + { + // Something went wrong, game is not loaded. Return to main screen. + auto loadOrQuitAction = GameActions::LoadOrQuitAction( + GameActions::LoadOrQuitModes::OpenSavePrompt, PromptMode::saveBeforeQuit); + GameActions::Execute(&loadOrQuitAction); + } + if (has_to_free) + { + free(data); + } + } + } + + bool NetworkBase::LoadMap(IStream* stream) + { + bool result = false; + try + { + auto& context = GetContext(); + auto& objManager = context.GetObjectManager(); + auto importer = ParkImporter::CreateParkFile(context.GetObjectRepository()); + auto loadResult = importer->LoadFromStream(stream, false); + objManager.LoadObjects(loadResult.RequiredObjects); + + MapAnimations::ClearAll(); + // TODO: Have a separate GameState and exchange once loaded. + auto& gameState = getGameState(); + importer->Import(gameState); + + EntityTweener::Get().Reset(); + MapAnimations::MarkAllTiles(); + + gLastAutoSaveUpdate = kAutosavePause; + result = true; + } + catch (const std::exception& e) + { + Console::Error::WriteLine("Unable to read map from server: %s", e.what()); + } + return result; + } + + bool NetworkBase::SaveMap(IStream* stream, const std::vector& objects) const + { + bool result = false; + PrepareMapForSave(); + try + { + auto exporter = std::make_unique(); + exporter->ExportObjectsList = objects; + + auto& gameState = getGameState(); + exporter->Export(gameState, *stream, kParkFileNetCompressionLevel); + result = true; + } + catch (const std::exception& e) + { + Console::Error::WriteLine("Unable to serialise map: %s", e.what()); + } + return result; + } + + void NetworkBase::Client_Handle_CHAT([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + auto text = packet.ReadString(); + if (!text.empty()) + { + ChatAddHistory(std::string(text)); + } + } + + static bool ProcessChatMessagePluginHooks(uint8_t playerId, std::string& text) + { + #ifdef ENABLE_SCRIPTING + auto& hookEngine = GetContext()->GetScriptEngine().GetHookEngine(); + if (hookEngine.HasSubscriptions(Scripting::HookType::networkChat)) + { + auto ctx = GetContext()->GetScriptEngine().GetContext(); + + // Create event args object + auto objIdx = duk_push_object(ctx); + duk_push_number(ctx, playerId); + duk_put_prop_string(ctx, objIdx, "player"); + duk_push_string(ctx, text.c_str()); + duk_put_prop_string(ctx, objIdx, "message"); + auto e = DukValue::take_from_stack(ctx); + + // Call the subscriptions + hookEngine.Call(Scripting::HookType::networkChat, e, false); + + // Update text from object if subscriptions changed it + if (e["message"].type() != DukValue::Type::STRING) + { + // Subscription set text to non-string, do not relay message + return false; + } + text = e["message"].as_string(); + if (text.empty()) + { + // Subscription set text to empty string, do not relay message + return false; + } + } + #endif + return true; + } + + void NetworkBase::ServerHandleChat(NetworkConnection& connection, NetworkPacket& packet) + { + auto szText = packet.ReadString(); + if (szText.empty()) + return; + + if (connection.Player != nullptr) + { + NetworkGroup* group = GetGroupByID(connection.Player->Group); + if (group == nullptr || !group->CanPerformAction(NetworkPermission::Chat)) + { + return; + } + } + + std::string text(szText); + if (connection.Player != nullptr) + { + if (!ProcessChatMessagePluginHooks(connection.Player->Id, text)) + { + // Message not to be relayed + return; + } + } + + const char* formatted = FormatChat(connection.Player, text.c_str()); + ChatAddHistory(formatted); + ServerSendChat(formatted); + } + + void NetworkBase::Client_Handle_GAME_ACTION([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t tick; + GameCommand actionType; + packet >> tick >> actionType; + + MemoryStream stream; + const size_t size = packet.Header.Size - packet.BytesRead; + stream.WriteArray(packet.Read(size), size); + stream.SetPosition(0); + + DataSerialiser ds(false, stream); + + GameActions::GameAction::Ptr action = GameActions::Create(actionType); + if (action == nullptr) + { + LOG_ERROR("Received unregistered game action type: 0x%08X", actionType); + return; + } + action->Serialise(ds); + + if (player_id == action->GetPlayer().id) + { + // Only execute callbacks that belong to us, + // clients can have identical network ids assigned. + auto itr = _gameActionCallbacks.find(action->GetNetworkId()); + if (itr != _gameActionCallbacks.end()) + { + action->SetCallback(itr->second); + _gameActionCallbacks.erase(itr); + } + } + + GameActions::Enqueue(std::move(action), tick); + } + + void NetworkBase::ServerHandleGameAction(NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t tick; + GameCommand actionType; + + NetworkPlayer* player = connection.Player; + if (player == nullptr) + { + return; + } + + packet >> tick >> actionType; + + // Don't let clients send pause or quit + if (actionType == GameCommand::TogglePause || actionType == GameCommand::LoadOrQuit) + { + return; + } + + if (actionType != GameCommand::Custom) + { + // Check if player's group permission allows command to run + NetworkGroup* group = GetGroupByID(connection.Player->Group); + if (group == nullptr || group->CanPerformCommand(actionType) == false) + { + ServerSendShowError(connection, STR_CANT_DO_THIS, STR_PERMISSION_DENIED); + return; + } + } + + // Create and enqueue the action. + GameActions::GameAction::Ptr ga = GameActions::Create(actionType); + if (ga == nullptr) + { + LOG_ERROR( + "Received unregistered game action type: 0x%08X from player: (%d) %s", actionType, connection.Player->Id, + connection.Player->Name.c_str()); + return; + } + + // Player who is hosting is not affected by cooldowns. + if ((player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) == 0) + { + auto cooldownIt = player->CooldownTime.find(actionType); + if (cooldownIt != std::end(player->CooldownTime)) + { + if (cooldownIt->second > 0) + { + ServerSendShowError(connection, STR_CANT_DO_THIS, STR_NETWORK_ACTION_RATE_LIMIT_MESSAGE); + return; + } + } + + uint32_t cooldownTime = ga->GetCooldownTime(); + if (cooldownTime > 0) + { + player->CooldownTime[actionType] = cooldownTime; + } + } + + DataSerialiser stream(false); + const size_t size = packet.Header.Size - packet.BytesRead; + stream.GetStream().WriteArray(packet.Read(size), size); + stream.GetStream().SetPosition(0); + + ga->Serialise(stream); + // Set player to sender, should be 0 if sent from client. + ga->SetPlayer(NetworkPlayerId_t{ connection.Player->Id }); + + GameActions::Enqueue(std::move(ga), tick); + } + + void NetworkBase::Client_Handle_TICK([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t srand0; + uint32_t flags; + uint32_t serverTick; + + packet >> serverTick >> srand0 >> flags; + + ServerTickData tickData; + tickData.srand0 = srand0; + tickData.tick = serverTick; + + if (flags & NETWORK_TICK_FLAG_CHECKSUMS) + { + auto text = packet.ReadString(); + if (!text.empty()) + { + tickData.spriteHash = text; + } + } + + // Don't let the history grow too much. + while (_serverTickData.size() >= 100) + { + _serverTickData.erase(_serverTickData.begin()); + } + + _serverState.tick = serverTick; + _serverTickData.emplace(serverTick, tickData); + } + + void NetworkBase::Client_Handle_PLAYERINFO([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t tick; + packet >> tick; + + NetworkPlayer playerInfo; + playerInfo.Read(packet); + + _pendingPlayerInfo.emplace(tick, playerInfo); + } + + void NetworkBase::Client_Handle_PLAYERLIST([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint32_t tick; + uint8_t size; + packet >> tick >> size; + + auto& pending = _pendingPlayerLists[tick]; + pending.players.clear(); + + for (uint32_t i = 0; i < size; i++) + { + NetworkPlayer tempplayer; + tempplayer.Read(packet); + + pending.players.push_back(std::move(tempplayer)); + } + } + + void NetworkBase::Client_Handle_PING([[maybe_unused]] NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) + { + Client_Send_PING(); + } + + void NetworkBase::ServerHandlePing(NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) + { + int32_t ping = Platform::GetTicks() - connection.PingTime; + if (ping < 0) + { + ping = 0; + } + if (connection.Player != nullptr) + { + connection.Player->Ping = ping; + auto* windowMgr = Ui::GetWindowManager(); + windowMgr->InvalidateByNumber(WindowClass::Player, connection.Player->Id); + } + } + + void NetworkBase::Client_Handle_PINGLIST([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint8_t size; + packet >> size; + for (uint32_t i = 0; i < size; i++) + { + uint8_t id; + uint16_t ping; + packet >> id >> ping; + NetworkPlayer* player = GetPlayerByID(id); + if (player != nullptr) + { + player->Ping = ping; + } + } + + auto* windowMgr = Ui::GetWindowManager(); + windowMgr->InvalidateByClass(WindowClass::Player); + } + + void NetworkBase::Client_Handle_SETDISCONNECTMSG(NetworkConnection& connection, NetworkPacket& packet) + { + auto disconnectmsg = packet.ReadString(); + if (!disconnectmsg.empty()) + { + connection.SetLastDisconnectReason(disconnectmsg); + } + } + + void NetworkBase::ServerHandleGameInfo(NetworkConnection& connection, [[maybe_unused]] NetworkPacket& packet) + { + ServerSendGameInfo(connection); + } + + void NetworkBase::Client_Handle_SHOWERROR([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + StringId title, message; + packet >> title >> message; + ContextShowError(title, message, {}); + } + + void NetworkBase::Client_Handle_GROUPLIST([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + group_list.clear(); + uint8_t size; + packet >> size >> default_group; + for (uint32_t i = 0; i < size; i++) + { + NetworkGroup group; + group.Read(packet); + auto newgroup = std::make_unique(group); + group_list.push_back(std::move(newgroup)); + } + } + + void NetworkBase::Client_Handle_EVENT([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + uint16_t eventType; + packet >> eventType; + switch (eventType) + { + case SERVER_EVENT_PLAYER_JOINED: + { + auto playerName = packet.ReadString(); + auto message = FormatStringID(STR_MULTIPLAYER_PLAYER_HAS_JOINED_THE_GAME, playerName); + ChatAddHistory(message); + break; + } + case SERVER_EVENT_PLAYER_DISCONNECTED: + { + auto playerName = packet.ReadString(); + auto reason = packet.ReadString(); + std::string message; + if (reason.empty()) + { + message = FormatStringID(STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_NO_REASON, playerName); + } + else + { + message = FormatStringID(STR_MULTIPLAYER_PLAYER_HAS_DISCONNECTED_WITH_REASON, playerName, reason); + } + ChatAddHistory(message); + break; + } + } + } + + void NetworkBase::Client_Send_GAMEINFO() + { + LOG_VERBOSE("requesting gameinfo"); + NetworkPacket packet(NetworkCommand::GameInfo); + _serverConnection->QueuePacket(std::move(packet)); + } + + void NetworkBase::Client_Handle_GAMEINFO([[maybe_unused]] NetworkConnection& connection, NetworkPacket& packet) + { + auto jsonString = packet.ReadString(); + packet >> _serverState.gamestateSnapshotsEnabled; + packet >> IsServerPlayerInvisible; + + json_t jsonData = Json::FromString(jsonString); + + if (jsonData.is_object()) + { + ServerName = Json::GetString(jsonData["name"]); + ServerDescription = Json::GetString(jsonData["description"]); + ServerGreeting = Json::GetString(jsonData["greeting"]); + + json_t jsonProvider = jsonData["provider"]; + if (jsonProvider.is_object()) + { + ServerProviderName = Json::GetString(jsonProvider["name"]); + ServerProviderEmail = Json::GetString(jsonProvider["email"]); + ServerProviderWebsite = Json::GetString(jsonProvider["website"]); + } + } + + NetworkChatShowServerGreeting(); + } + + void NetworkReconnect() + { + GetContext()->GetNetwork().Reconnect(); + } + + void NetworkShutdownClient() + { + GetContext()->GetNetwork().ServerClientDisconnected(); + } + + int32_t NetworkBeginClient(const std::string& host, int32_t port) + { + return GetContext()->GetNetwork().BeginClient(host, port); + } + + int32_t NetworkBeginServer(int32_t port, const std::string& address) + { + return GetContext()->GetNetwork().BeginServer(port, address); + } + + void NetworkUpdate() + { + GetContext()->GetNetwork().Update(); + } + + void NetworkProcessPending() + { + GetContext()->GetNetwork().ProcessPending(); + } + + void NetworkFlush() + { + GetContext()->GetNetwork().Flush(); + } + + int32_t NetworkGetMode() + { + return GetContext()->GetNetwork().GetMode(); + } + + int32_t NetworkGetStatus() + { + return GetContext()->GetNetwork().GetStatus(); + } + + bool NetworkIsDesynchronised() + { + return GetContext()->GetNetwork().IsDesynchronised(); + } + + bool NetworkCheckDesynchronisation() + { + return GetContext()->GetNetwork().CheckDesynchronizaton(); + } + + void NetworkRequestGamestateSnapshot() + { + return GetContext()->GetNetwork().RequestStateSnapshot(); + } + + void NetworkSendTick() + { + GetContext()->GetNetwork().ServerSendTick(); + } + + NetworkAuth NetworkGetAuthstatus() + { + return GetContext()->GetNetwork().GetAuthStatus(); + } + + uint32_t NetworkGetServerTick() + { + return GetContext()->GetNetwork().GetServerTick(); + } + + uint8_t NetworkGetCurrentPlayerId() + { + return GetContext()->GetNetwork().GetPlayerID(); + } + + int32_t NetworkGetNumPlayers() + { + return GetContext()->GetNetwork().GetTotalNumPlayers(); + } + + int32_t NetworkGetNumVisiblePlayers() + { + return GetContext()->GetNetwork().GetNumVisiblePlayers(); + } + + const char* NetworkGetPlayerName(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + return static_cast(network.player_list[index]->Name.c_str()); + } + + uint32_t NetworkGetPlayerFlags(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + return network.player_list[index]->Flags; + } + + int32_t NetworkGetPlayerPing(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + return network.player_list[index]->Ping; + } + + int32_t NetworkGetPlayerID(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + return network.player_list[index]->Id; + } + + money64 NetworkGetPlayerMoneySpent(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + return network.player_list[index]->MoneySpent; + } + + std::string NetworkGetPlayerIPAddress(uint32_t id) + { + auto& network = GetContext()->GetNetwork(); + auto conn = network.GetPlayerConnection(id); + if (conn != nullptr && conn->Socket != nullptr) + { + return conn->Socket->GetIpAddress(); + } + return {}; + } + + std::string NetworkGetPlayerPublicKeyHash(uint32_t id) + { + auto& network = GetContext()->GetNetwork(); + auto player = network.GetPlayerByID(id); + if (player != nullptr) + { + return player->KeyHash; + } + return {}; + } + + void NetworkIncrementPlayerNumCommands(uint32_t playerIndex) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(playerIndex, network.player_list); + + network.player_list[playerIndex]->IncrementNumCommands(); + } + + void NetworkAddPlayerMoneySpent(uint32_t index, money64 cost) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + network.player_list[index]->AddMoneySpent(cost); + } + + int32_t NetworkGetPlayerLastAction(uint32_t index, int32_t time) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + if (time && Platform::GetTicks() > network.player_list[index]->LastActionTime + time) + { + return -999; + } + return network.player_list[index]->LastAction; + } + + void NetworkSetPlayerLastAction(uint32_t index, GameCommand command) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + network.player_list[index]->LastAction = static_cast(NetworkActions::FindCommand(command)); + network.player_list[index]->LastActionTime = Platform::GetTicks(); + } + + CoordsXYZ NetworkGetPlayerLastActionCoord(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, GetContext()->GetNetwork().player_list); + + return network.player_list[index]->LastActionCoord; + } + + void NetworkSetPlayerLastActionCoord(uint32_t index, const CoordsXYZ& coord) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + if (index < network.player_list.size()) + { + network.player_list[index]->LastActionCoord = coord; + } + } + + uint32_t NetworkGetPlayerCommandsRan(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, GetContext()->GetNetwork().player_list); + + return network.player_list[index]->CommandsRan; + } + + int32_t NetworkGetPlayerIndex(uint32_t id) + { + auto& network = GetContext()->GetNetwork(); + auto it = network.GetPlayerIteratorByID(id); + if (it == network.player_list.end()) + { + return -1; + } + return static_cast(network.GetPlayerIteratorByID(id) - network.player_list.begin()); + } + + uint8_t NetworkGetPlayerGroup(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + + return network.player_list[index]->Group; + } + + void NetworkSetPlayerGroup(uint32_t index, uint32_t groupindex) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.player_list); + Guard::IndexInRange(groupindex, network.group_list); + + network.player_list[index]->Group = network.group_list[groupindex]->Id; + } + + int32_t NetworkGetGroupIndex(uint8_t id) + { + auto& network = GetContext()->GetNetwork(); + auto it = network.GetGroupIteratorByID(id); + if (it == network.group_list.end()) + { + return -1; + } + return static_cast(network.GetGroupIteratorByID(id) - network.group_list.begin()); + } + + uint8_t NetworkGetGroupID(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(index, network.group_list); + + return network.group_list[index]->Id; + } + + int32_t NetworkGetNumGroups() + { + auto& network = GetContext()->GetNetwork(); + return static_cast(network.group_list.size()); + } + + const char* NetworkGetGroupName(uint32_t index) + { + auto& network = GetContext()->GetNetwork(); + return network.group_list[index]->GetName().c_str(); + } + + void NetworkChatShowConnectedMessage() + { + auto windowManager = Ui::GetWindowManager(); + std::string s = windowManager->GetKeyboardShortcutString("interface.misc.multiplayer_chat"); + const char* sptr = s.c_str(); + + utf8 buffer[256]; + FormatStringLegacy(buffer, sizeof(buffer), STR_MULTIPLAYER_CONNECTED_CHAT_HINT, &sptr); + + NetworkPlayer server; + server.Name = "Server"; + const char* formatted = NetworkBase::FormatChat(&server, buffer); + ChatAddHistory(formatted); + } + + // Display server greeting if one exists + void NetworkChatShowServerGreeting() + { + const auto& greeting = NetworkGetServerGreeting(); + if (!greeting.empty()) + { + thread_local std::string greeting_formatted; + greeting_formatted.assign("{OUTLINE}{GREEN}"); + greeting_formatted += greeting; + ChatAddHistory(greeting_formatted); + } + } + + GameActions::Result NetworkSetPlayerGroup( + NetworkPlayerId_t actionPlayerId, NetworkPlayerId_t playerId, uint8_t groupId, bool isExecuting) + { + auto& network = GetContext()->GetNetwork(); + NetworkPlayer* player = network.GetPlayerByID(playerId); + + NetworkGroup* fromgroup = network.GetGroupByID(actionPlayerId); + if (player == nullptr) + { + return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_DO_THIS, kStringIdNone); + } + + if (network.GetGroupByID(groupId) == nullptr) + { + return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_DO_THIS, kStringIdNone); + } + + if (player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) + { return GameActions::Result( - GameActions::Status::InvalidParameters, STR_ERR_INVALID_PARAMETER, STR_ERR_VALUE_OUT_OF_RANGE); - } + GameActions::Status::InvalidParameters, STR_CANT_CHANGE_GROUP_THAT_THE_HOST_BELONGS_TO, kStringIdNone); + } - network.SaveGroups(); - - return GameActions::Result(); -} - -GameActions::Result NetworkKickPlayer(NetworkPlayerId_t playerId, bool isExecuting) -{ - auto& network = GetContext()->GetNetwork(); - NetworkPlayer* player = network.GetPlayerByID(playerId); - if (player == nullptr) - { - // Player might be already removed by the PLAYERLIST command, need to refactor non-game commands executing too - // early. - return GameActions::Result(GameActions::Status::InvalidParameters, STR_ERR_INVALID_PARAMETER, STR_ERR_PLAYER_NOT_FOUND); - } - - if (player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) - { - return GameActions::Result(GameActions::Status::Disallowed, STR_CANT_KICK_THE_HOST, kStringIdNone); - } - - if (isExecuting) - { - if (network.GetMode() == NETWORK_MODE_SERVER) + if (groupId == 0 && fromgroup != nullptr && fromgroup->Id != 0) { - network.KickPlayer(playerId); + return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_SET_TO_THIS_GROUP, kStringIdNone); + } - NetworkUserManager& networkUserManager = network._userManager; - networkUserManager.Load(); - networkUserManager.RemoveUser(player->KeyHash); - networkUserManager.Save(); + if (isExecuting) + { + player->Group = groupId; + + if (NetworkGetMode() == NETWORK_MODE_SERVER) + { + // Add or update saved user + NetworkUserManager& userManager = network._userManager; + NetworkUser* networkUser = userManager.GetOrAddUser(player->KeyHash); + networkUser->GroupId = groupId; + networkUser->Name = player->Name; + userManager.Save(); + } + + auto* windowMgr = Ui::GetWindowManager(); + windowMgr->InvalidateByNumber(WindowClass::Player, playerId); + + // Log set player group event + NetworkPlayer* game_command_player = network.GetPlayerByID(actionPlayerId); + NetworkGroup* new_player_group = network.GetGroupByID(groupId); + char log_msg[256]; + const char* args[3] = { + player->Name.c_str(), + new_player_group->GetName().c_str(), + game_command_player->Name.c_str(), + }; + FormatStringLegacy(log_msg, 256, STR_LOG_SET_PLAYER_GROUP, args); + NetworkAppendServerLog(log_msg); + } + return GameActions::Result(); + } + + GameActions::Result NetworkModifyGroups( + NetworkPlayerId_t actionPlayerId, GameActions::ModifyGroupType type, uint8_t groupId, const std::string& name, + uint32_t permissionIndex, GameActions::PermissionState permissionState, bool isExecuting) + { + auto& network = GetContext()->GetNetwork(); + switch (type) + { + case GameActions::ModifyGroupType::AddGroup: + { + if (isExecuting) + { + NetworkGroup* newgroup = network.AddGroup(); + if (newgroup == nullptr) + { + return GameActions::Result(GameActions::Status::Unknown, STR_CANT_DO_THIS, kStringIdNone); + } + } + } + break; + case GameActions::ModifyGroupType::RemoveGroup: + { + if (groupId == 0) + { + return GameActions::Result( + GameActions::Status::Disallowed, STR_THIS_GROUP_CANNOT_BE_MODIFIED, kStringIdNone); + } + for (const auto& it : network.player_list) + { + if ((it.get())->Group == groupId) + { + return GameActions::Result( + GameActions::Status::Disallowed, STR_CANT_REMOVE_GROUP_THAT_PLAYERS_BELONG_TO, kStringIdNone); + } + } + if (isExecuting) + { + network.RemoveGroup(groupId); + } + } + break; + case GameActions::ModifyGroupType::SetPermissions: + { + if (groupId == 0) + { // can't change admin group permissions + return GameActions::Result( + GameActions::Status::Disallowed, STR_THIS_GROUP_CANNOT_BE_MODIFIED, kStringIdNone); + } + NetworkGroup* mygroup = nullptr; + NetworkPlayer* player = network.GetPlayerByID(actionPlayerId); + auto networkPermission = static_cast(permissionIndex); + if (player != nullptr && permissionState == GameActions::PermissionState::Toggle) + { + mygroup = network.GetGroupByID(player->Group); + if (mygroup == nullptr || !mygroup->CanPerformAction(networkPermission)) + { + return GameActions::Result( + GameActions::Status::Disallowed, STR_CANT_MODIFY_PERMISSION_THAT_YOU_DO_NOT_HAVE_YOURSELF, + kStringIdNone); + } + } + if (isExecuting) + { + NetworkGroup* group = network.GetGroupByID(groupId); + if (group != nullptr) + { + if (permissionState != GameActions::PermissionState::Toggle) + { + if (mygroup != nullptr) + { + if (permissionState == GameActions::PermissionState::SetAll) + { + group->ActionsAllowed = mygroup->ActionsAllowed; + } + else + { + group->ActionsAllowed.fill(0x00); + } + } + } + else + { + group->ToggleActionPermission(networkPermission); + } + } + } + } + break; + case GameActions::ModifyGroupType::SetName: + { + NetworkGroup* group = network.GetGroupByID(groupId); + if (group == nullptr) + { + return GameActions::Result(GameActions::Status::InvalidParameters, STR_CANT_RENAME_GROUP, kStringIdNone); + } + + const char* oldName = group->GetName().c_str(); + + if (strcmp(oldName, name.c_str()) == 0) + { + return GameActions::Result(); + } + + if (name.empty()) + { + return GameActions::Result( + GameActions::Status::InvalidParameters, STR_CANT_RENAME_GROUP, STR_INVALID_GROUP_NAME); + } + + if (isExecuting) + { + if (group != nullptr) + { + group->SetName(name); + } + } + } + break; + case GameActions::ModifyGroupType::SetDefault: + { + if (groupId == 0) + { + return GameActions::Result(GameActions::Status::Disallowed, STR_CANT_SET_TO_THIS_GROUP, kStringIdNone); + } + if (isExecuting) + { + network.SetDefaultGroup(groupId); + } + } + break; + default: + LOG_ERROR("Invalid Modify Group Type: %u", static_cast(type)); + return GameActions::Result( + GameActions::Status::InvalidParameters, STR_ERR_INVALID_PARAMETER, STR_ERR_VALUE_OUT_OF_RANGE); + } + + network.SaveGroups(); + + return GameActions::Result(); + } + + GameActions::Result NetworkKickPlayer(NetworkPlayerId_t playerId, bool isExecuting) + { + auto& network = GetContext()->GetNetwork(); + NetworkPlayer* player = network.GetPlayerByID(playerId); + if (player == nullptr) + { + // Player might be already removed by the PLAYERLIST command, need to refactor non-game commands executing too + // early. + return GameActions::Result( + GameActions::Status::InvalidParameters, STR_ERR_INVALID_PARAMETER, STR_ERR_PLAYER_NOT_FOUND); + } + + if (player->Flags & NETWORK_PLAYER_FLAG_ISSERVER) + { + return GameActions::Result(GameActions::Status::Disallowed, STR_CANT_KICK_THE_HOST, kStringIdNone); + } + + if (isExecuting) + { + if (network.GetMode() == NETWORK_MODE_SERVER) + { + network.KickPlayer(playerId); + + NetworkUserManager& networkUserManager = network._userManager; + networkUserManager.Load(); + networkUserManager.RemoveUser(player->KeyHash); + networkUserManager.Save(); + } + } + return GameActions::Result(); + } + + uint8_t NetworkGetDefaultGroup() + { + auto& network = GetContext()->GetNetwork(); + return network.GetDefaultGroup(); + } + + int32_t NetworkGetNumActions() + { + return static_cast(NetworkActions::Actions.size()); + } + + StringId NetworkGetActionNameStringID(uint32_t index) + { + if (index < NetworkActions::Actions.size()) + { + return NetworkActions::Actions[index].Name; + } + + return kStringIdNone; + } + + int32_t NetworkCanPerformAction(uint32_t groupindex, NetworkPermission index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(groupindex, network.group_list); + + return network.group_list[groupindex]->CanPerformAction(index); + } + + int32_t NetworkCanPerformCommand(uint32_t groupindex, int32_t index) + { + auto& network = GetContext()->GetNetwork(); + Guard::IndexInRange(groupindex, network.group_list); + + return network.group_list[groupindex]->CanPerformCommand(static_cast(index)); // TODO + } + + void NetworkSetPickupPeep(uint8_t playerid, Peep* peep) + { + auto& network = GetContext()->GetNetwork(); + if (network.GetMode() == NETWORK_MODE_NONE) + { + _pickup_peep = peep; + } + else + { + NetworkPlayer* player = network.GetPlayerByID(playerid); + if (player != nullptr) + { + player->PickupPeep = peep; + } } } - return GameActions::Result(); -} -uint8_t NetworkGetDefaultGroup() -{ - auto& network = GetContext()->GetNetwork(); - return network.GetDefaultGroup(); -} - -int32_t NetworkGetNumActions() -{ - return static_cast(NetworkActions::Actions.size()); -} - -StringId NetworkGetActionNameStringID(uint32_t index) -{ - if (index < NetworkActions::Actions.size()) + Peep* NetworkGetPickupPeep(uint8_t playerid) { - return NetworkActions::Actions[index].Name; + auto& network = GetContext()->GetNetwork(); + if (network.GetMode() == NETWORK_MODE_NONE) + { + return _pickup_peep; + } + + NetworkPlayer* player = network.GetPlayerByID(playerid); + if (player != nullptr) + { + return player->PickupPeep; + } + return nullptr; } - return kStringIdNone; -} + void NetworkSetPickupPeepOldX(uint8_t playerid, int32_t x) + { + auto& network = GetContext()->GetNetwork(); + if (network.GetMode() == NETWORK_MODE_NONE) + { + _pickup_peep_old_x = x; + } + else + { + NetworkPlayer* player = network.GetPlayerByID(playerid); + if (player != nullptr) + { + player->PickupPeepOldX = x; + } + } + } -int32_t NetworkCanPerformAction(uint32_t groupindex, NetworkPermission index) + int32_t NetworkGetPickupPeepOldX(uint8_t playerid) + { + auto& network = GetContext()->GetNetwork(); + if (network.GetMode() == NETWORK_MODE_NONE) + { + return _pickup_peep_old_x; + } + + NetworkPlayer* player = network.GetPlayerByID(playerid); + if (player != nullptr) + { + return player->PickupPeepOldX; + } + return -1; + } + + bool NetworkIsServerPlayerInvisible() + { + return GetContext()->GetNetwork().IsServerPlayerInvisible; + } + + int32_t NetworkGetCurrentPlayerGroupIndex() + { + auto& network = GetContext()->GetNetwork(); + NetworkPlayer* player = network.GetPlayerByID(network.GetPlayerID()); + if (player != nullptr) + { + return NetworkGetGroupIndex(player->Group); + } + return -1; + } + + void NetworkSendChat(const char* text, const std::vector& playerIds) + { + auto& network = GetContext()->GetNetwork(); + if (network.GetMode() == NETWORK_MODE_CLIENT) + { + network.Client_Send_CHAT(text); + } + else if (network.GetMode() == NETWORK_MODE_SERVER) + { + std::string message = text; + if (ProcessChatMessagePluginHooks(network.GetPlayerID(), message)) + { + auto player = network.GetPlayerByID(network.GetPlayerID()); + if (player != nullptr) + { + auto formatted = network.FormatChat(player, message.c_str()); + if (playerIds.empty() + || std::find(playerIds.begin(), playerIds.end(), network.GetPlayerID()) != playerIds.end()) + { + // Server is one of the recipients + ChatAddHistory(formatted); + } + network.ServerSendChat(formatted, playerIds); + } + } + } + } + + void NetworkSendGameAction(const GameActions::GameAction* action) + { + auto& network = GetContext()->GetNetwork(); + switch (network.GetMode()) + { + case NETWORK_MODE_SERVER: + network.ServerSendGameAction(action); + break; + case NETWORK_MODE_CLIENT: + network.Client_Send_GAME_ACTION(action); + break; + } + } + + void NetworkSendPassword(const std::string& password) + { + auto& network = GetContext()->GetNetwork(); + const auto keyPath = NetworkGetPrivateKeyPath(Config::Get().network.PlayerName); + if (!File::Exists(keyPath)) + { + LOG_ERROR("Private key %s missing! Restart the game to generate it.", keyPath.c_str()); + return; + } + try + { + auto fs = FileStream(keyPath, FileMode::open); + network._key.LoadPrivate(&fs); + } + catch (const std::exception&) + { + LOG_ERROR("Error reading private key from %s.", keyPath.c_str()); + return; + } + const std::string pubkey = network._key.PublicKeyString(); + + std::vector signature; + network._key.Sign(network._challenge.data(), network._challenge.size(), signature); + // Don't keep private key in memory. There's no need and it may get leaked + // when process dump gets collected at some point in future. + network._key.Unload(); + network.Client_Send_AUTH(Config::Get().network.PlayerName, password, pubkey, signature); + } + + void NetworkSetPassword(const char* password) + { + auto& network = GetContext()->GetNetwork(); + network.SetPassword(password); + } + + void NetworkAppendChatLog(std::string_view text) + { + auto& network = GetContext()->GetNetwork(); + network.AppendChatLog(text); + } + + void NetworkAppendServerLog(const utf8* text) + { + auto& network = GetContext()->GetNetwork(); + network.AppendServerLog(text); + } + + static u8string NetworkGetKeysDirectory() + { + auto& env = GetContext()->GetPlatformEnvironment(); + return Path::Combine(env.GetDirectoryPath(DirBase::user), u8"keys"); + } + + static u8string NetworkGetPrivateKeyPath(u8string_view playerName) + { + return Path::Combine(NetworkGetKeysDirectory(), u8string(playerName) + u8".privkey"); + } + + static u8string NetworkGetPublicKeyPath(u8string_view playerName, u8string_view hash) + { + const auto filename = u8string(playerName) + u8"-" + u8string(hash) + u8".pubkey"; + return Path::Combine(NetworkGetKeysDirectory(), filename); + } + + u8string NetworkGetServerName() + { + auto& network = GetContext()->GetNetwork(); + return network.ServerName; + } + u8string NetworkGetServerDescription() + { + auto& network = GetContext()->GetNetwork(); + return network.ServerDescription; + } + u8string NetworkGetServerGreeting() + { + auto& network = GetContext()->GetNetwork(); + return network.ServerGreeting; + } + u8string NetworkGetServerProviderName() + { + auto& network = GetContext()->GetNetwork(); + return network.ServerProviderName; + } + u8string NetworkGetServerProviderEmail() + { + auto& network = GetContext()->GetNetwork(); + return network.ServerProviderEmail; + } + u8string NetworkGetServerProviderWebsite() + { + auto& network = GetContext()->GetNetwork(); + return network.ServerProviderWebsite; + } + + std::string NetworkGetVersion() + { + return kNetworkStreamID; + } + + NetworkStats NetworkGetStats() + { + auto& network = GetContext()->GetNetwork(); + return network.GetStats(); + } + + NetworkServerState NetworkGetServerState() + { + auto& network = GetContext()->GetNetwork(); + return network.GetServerState(); + } + + bool NetworkGamestateSnapshotsEnabled() + { + return NetworkGetServerState().gamestateSnapshotsEnabled; + } + + json_t NetworkGetServerInfoAsJson() + { + auto& network = GetContext()->GetNetwork(); + return network.GetServerInfoAsJson(); + } + +} // namespace OpenRCT2::Network + +#else // DISABLE_NETWORK + +namespace OpenRCT2::Network { - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(groupindex, network.group_list); + int32_t NetworkGetMode() + { + return NETWORK_MODE_NONE; + } + int32_t NetworkGetStatus() + { + return NETWORK_STATUS_NONE; + } + NetworkAuth NetworkGetAuthstatus() + { + return NetworkAuth::None; + } + uint32_t NetworkGetServerTick() + { + return getGameState().currentTicks; + } + void NetworkFlush() + { + } + void NetworkSendTick() + { + } + bool NetworkIsDesynchronised() + { + return false; + } + bool NetworkGamestateSnapshotsEnabled() + { + return false; + } + bool NetworkCheckDesynchronisation() + { + return false; + } + void NetworkRequestGamestateSnapshot() + { + } + void NetworkSendGameAction(const GameActions::GameAction* action) + { + } + void NetworkUpdate() + { + } + void NetworkProcessPending() + { + } + int32_t NetworkBeginClient(const std::string& host, int32_t port) + { + return 1; + } + int32_t NetworkBeginServer(int32_t port, const std::string& address) + { + return 1; + } + int32_t NetworkGetNumPlayers() + { + return 1; + } + int32_t NetworkGetNumVisiblePlayers() + { + return 1; + } + const char* NetworkGetPlayerName(uint32_t index) + { + return "local (OpenRCT2 compiled without MP)"; + } + uint32_t NetworkGetPlayerFlags(uint32_t index) + { + return 0; + } + int32_t NetworkGetPlayerPing(uint32_t index) + { + return 0; + } + int32_t NetworkGetPlayerID(uint32_t index) + { + return 0; + } + money64 NetworkGetPlayerMoneySpent(uint32_t index) + { + return 0.00_GBP; + } + std::string NetworkGetPlayerIPAddress(uint32_t id) + { + return {}; + } + std::string NetworkGetPlayerPublicKeyHash(uint32_t id) + { + return {}; + } + void NetworkIncrementPlayerNumCommands(uint32_t playerIndex) + { + } + void NetworkAddPlayerMoneySpent(uint32_t index, money64 cost) + { + } + int32_t NetworkGetPlayerLastAction(uint32_t index, int32_t time) + { + return -999; + } + void NetworkSetPlayerLastAction(uint32_t index, GameCommand command) + { + } + CoordsXYZ NetworkGetPlayerLastActionCoord(uint32_t index) + { + return { 0, 0, 0 }; + } + void NetworkSetPlayerLastActionCoord(uint32_t index, const CoordsXYZ& coord) + { + } + uint32_t NetworkGetPlayerCommandsRan(uint32_t index) + { + return 0; + } + int32_t NetworkGetPlayerIndex(uint32_t id) + { + return -1; + } + uint8_t NetworkGetPlayerGroup(uint32_t index) + { + return 0; + } + void NetworkSetPlayerGroup(uint32_t index, uint32_t groupindex) + { + } + int32_t NetworkGetGroupIndex(uint8_t id) + { + return -1; + } + uint8_t NetworkGetGroupID(uint32_t index) + { + return 0; + } + int32_t NetworkGetNumGroups() + { + return 0; + } + const char* NetworkGetGroupName(uint32_t index) + { + return ""; + }; - return network.group_list[groupindex]->CanPerformAction(index); -} - -int32_t NetworkCanPerformCommand(uint32_t groupindex, int32_t index) -{ - auto& network = GetContext()->GetNetwork(); - Guard::IndexInRange(groupindex, network.group_list); - - return network.group_list[groupindex]->CanPerformCommand(static_cast(index)); // TODO -} - -void NetworkSetPickupPeep(uint8_t playerid, Peep* peep) -{ - auto& network = GetContext()->GetNetwork(); - if (network.GetMode() == NETWORK_MODE_NONE) + GameActions::Result NetworkSetPlayerGroup( + NetworkPlayerId_t actionPlayerId, NetworkPlayerId_t playerId, uint8_t groupId, bool isExecuting) + { + return GameActions::Result(); + } + GameActions::Result NetworkModifyGroups( + NetworkPlayerId_t actionPlayerId, GameActions::ModifyGroupType type, uint8_t groupId, const std::string& name, + uint32_t permissionIndex, GameActions::PermissionState permissionState, bool isExecuting) + { + return GameActions::Result(); + } + GameActions::Result NetworkKickPlayer(NetworkPlayerId_t playerId, bool isExecuting) + { + return GameActions::Result(); + } + uint8_t NetworkGetDefaultGroup() + { + return 0; + } + int32_t NetworkGetNumActions() + { + return 0; + } + StringId NetworkGetActionNameStringID(uint32_t index) + { + return -1; + } + int32_t NetworkCanPerformAction(uint32_t groupindex, NetworkPermission index) + { + return 0; + } + int32_t NetworkCanPerformCommand(uint32_t groupindex, int32_t index) + { + return 0; + } + void NetworkSetPickupPeep(uint8_t playerid, Peep* peep) { _pickup_peep = peep; } - else - { - NetworkPlayer* player = network.GetPlayerByID(playerid); - if (player != nullptr) - { - player->PickupPeep = peep; - } - } -} - -Peep* NetworkGetPickupPeep(uint8_t playerid) -{ - auto& network = GetContext()->GetNetwork(); - if (network.GetMode() == NETWORK_MODE_NONE) + Peep* NetworkGetPickupPeep(uint8_t playerid) { return _pickup_peep; } - - NetworkPlayer* player = network.GetPlayerByID(playerid); - if (player != nullptr) - { - return player->PickupPeep; - } - return nullptr; -} - -void NetworkSetPickupPeepOldX(uint8_t playerid, int32_t x) -{ - auto& network = GetContext()->GetNetwork(); - if (network.GetMode() == NETWORK_MODE_NONE) + void NetworkSetPickupPeepOldX(uint8_t playerid, int32_t x) { _pickup_peep_old_x = x; } - else - { - NetworkPlayer* player = network.GetPlayerByID(playerid); - if (player != nullptr) - { - player->PickupPeepOldX = x; - } - } -} - -int32_t NetworkGetPickupPeepOldX(uint8_t playerid) -{ - auto& network = GetContext()->GetNetwork(); - if (network.GetMode() == NETWORK_MODE_NONE) + int32_t NetworkGetPickupPeepOldX(uint8_t playerid) { return _pickup_peep_old_x; } - - NetworkPlayer* player = network.GetPlayerByID(playerid); - if (player != nullptr) + void NetworkSendChat(const char* text, const std::vector& playerIds) { - return player->PickupPeepOldX; } - return -1; -} - -bool NetworkIsServerPlayerInvisible() -{ - return GetContext()->GetNetwork().IsServerPlayerInvisible; -} - -int32_t NetworkGetCurrentPlayerGroupIndex() -{ - auto& network = GetContext()->GetNetwork(); - NetworkPlayer* player = network.GetPlayerByID(network.GetPlayerID()); - if (player != nullptr) + void NetworkSendPassword(const std::string& password) { - return NetworkGetGroupIndex(player->Group); } - return -1; -} - -void NetworkSendChat(const char* text, const std::vector& playerIds) -{ - auto& network = GetContext()->GetNetwork(); - if (network.GetMode() == NETWORK_MODE_CLIENT) + void NetworkReconnect() { - network.Client_Send_CHAT(text); } - else if (network.GetMode() == NETWORK_MODE_SERVER) + void NetworkShutdownClient() { - std::string message = text; - if (ProcessChatMessagePluginHooks(network.GetPlayerID(), message)) - { - auto player = network.GetPlayerByID(network.GetPlayerID()); - if (player != nullptr) - { - auto formatted = network.FormatChat(player, message.c_str()); - if (playerIds.empty() - || std::find(playerIds.begin(), playerIds.end(), network.GetPlayerID()) != playerIds.end()) - { - // Server is one of the recipients - ChatAddHistory(formatted); - } - network.ServerSendChat(formatted, playerIds); - } - } } -} - -void NetworkSendGameAction(const GameActions::GameAction* action) -{ - auto& network = GetContext()->GetNetwork(); - switch (network.GetMode()) + void NetworkSetPassword(const char* password) { - case NETWORK_MODE_SERVER: - network.ServerSendGameAction(action); - break; - case NETWORK_MODE_CLIENT: - network.Client_Send_GAME_ACTION(action); - break; } -} - -void NetworkSendPassword(const std::string& password) -{ - auto& network = GetContext()->GetNetwork(); - const auto keyPath = NetworkGetPrivateKeyPath(Config::Get().network.PlayerName); - if (!File::Exists(keyPath)) + uint8_t NetworkGetCurrentPlayerId() { - LOG_ERROR("Private key %s missing! Restart the game to generate it.", keyPath.c_str()); - return; + return 0; } - try + int32_t NetworkGetCurrentPlayerGroupIndex() { - auto fs = FileStream(keyPath, FileMode::open); - network._key.LoadPrivate(&fs); + return 0; } - catch (const std::exception&) + bool NetworkIsServerPlayerInvisible() { - LOG_ERROR("Error reading private key from %s.", keyPath.c_str()); - return; + return false; } - const std::string pubkey = network._key.PublicKeyString(); + void NetworkAppendChatLog(std::string_view) + { + } + void NetworkAppendServerLog(const utf8* text) + { + } + u8string NetworkGetServerName() + { + return u8string(); + } + u8string NetworkGetServerDescription() + { + return u8string(); + } + u8string NetworkGetServerGreeting() + { + return u8string(); + } + u8string NetworkGetServerProviderName() + { + return u8string(); + } + u8string NetworkGetServerProviderEmail() + { + return u8string(); + } + u8string NetworkGetServerProviderWebsite() + { + return u8string(); + } + std::string NetworkGetVersion() + { + return "Multiplayer disabled"; + } + NetworkStats NetworkGetStats() + { + return NetworkStats{}; + } + NetworkServerState NetworkGetServerState() + { + return NetworkServerState{}; + } + json_t NetworkGetServerInfoAsJson() + { + return {}; + } +} // namespace OpenRCT2::Network - std::vector signature; - network._key.Sign(network._challenge.data(), network._challenge.size(), signature); - // Don't keep private key in memory. There's no need and it may get leaked - // when process dump gets collected at some point in future. - network._key.Unload(); - network.Client_Send_AUTH(Config::Get().network.PlayerName, password, pubkey, signature); -} - -void NetworkSetPassword(const char* password) -{ - auto& network = GetContext()->GetNetwork(); - network.SetPassword(password); -} - -void NetworkAppendChatLog(std::string_view text) -{ - auto& network = GetContext()->GetNetwork(); - network.AppendChatLog(text); -} - -void NetworkAppendServerLog(const utf8* text) -{ - auto& network = GetContext()->GetNetwork(); - network.AppendServerLog(text); -} - -static u8string NetworkGetKeysDirectory() -{ - auto& env = GetContext()->GetPlatformEnvironment(); - return Path::Combine(env.GetDirectoryPath(DirBase::user), u8"keys"); -} - -static u8string NetworkGetPrivateKeyPath(u8string_view playerName) -{ - return Path::Combine(NetworkGetKeysDirectory(), u8string(playerName) + u8".privkey"); -} - -static u8string NetworkGetPublicKeyPath(u8string_view playerName, u8string_view hash) -{ - const auto filename = u8string(playerName) + u8"-" + u8string(hash) + u8".pubkey"; - return Path::Combine(NetworkGetKeysDirectory(), filename); -} - -u8string NetworkGetServerName() -{ - auto& network = GetContext()->GetNetwork(); - return network.ServerName; -} -u8string NetworkGetServerDescription() -{ - auto& network = GetContext()->GetNetwork(); - return network.ServerDescription; -} -u8string NetworkGetServerGreeting() -{ - auto& network = GetContext()->GetNetwork(); - return network.ServerGreeting; -} -u8string NetworkGetServerProviderName() -{ - auto& network = GetContext()->GetNetwork(); - return network.ServerProviderName; -} -u8string NetworkGetServerProviderEmail() -{ - auto& network = GetContext()->GetNetwork(); - return network.ServerProviderEmail; -} -u8string NetworkGetServerProviderWebsite() -{ - auto& network = GetContext()->GetNetwork(); - return network.ServerProviderWebsite; -} - -std::string NetworkGetVersion() -{ - return kNetworkStreamID; -} - -NetworkStats NetworkGetStats() -{ - auto& network = GetContext()->GetNetwork(); - return network.GetStats(); -} - -NetworkServerState NetworkGetServerState() -{ - auto& network = GetContext()->GetNetwork(); - return network.GetServerState(); -} - -bool NetworkGamestateSnapshotsEnabled() -{ - return NetworkGetServerState().gamestateSnapshotsEnabled; -} - -json_t NetworkGetServerInfoAsJson() -{ - auto& network = GetContext()->GetNetwork(); - return network.GetServerInfoAsJson(); -} -#else -int32_t NetworkGetMode() -{ - return NETWORK_MODE_NONE; -} -int32_t NetworkGetStatus() -{ - return NETWORK_STATUS_NONE; -} -NetworkAuth NetworkGetAuthstatus() -{ - return NetworkAuth::None; -} -uint32_t NetworkGetServerTick() -{ - return getGameState().currentTicks; -} -void NetworkFlush() -{ -} -void NetworkSendTick() -{ -} -bool NetworkIsDesynchronised() -{ - return false; -} -bool NetworkGamestateSnapshotsEnabled() -{ - return false; -} -bool NetworkCheckDesynchronisation() -{ - return false; -} -void NetworkRequestGamestateSnapshot() -{ -} -void NetworkSendGameAction(const GameActions::GameAction* action) -{ -} -void NetworkUpdate() -{ -} -void NetworkProcessPending() -{ -} -int32_t NetworkBeginClient(const std::string& host, int32_t port) -{ - return 1; -} -int32_t NetworkBeginServer(int32_t port, const std::string& address) -{ - return 1; -} -int32_t NetworkGetNumPlayers() -{ - return 1; -} -int32_t NetworkGetNumVisiblePlayers() -{ - return 1; -} -const char* NetworkGetPlayerName(uint32_t index) -{ - return "local (OpenRCT2 compiled without MP)"; -} -uint32_t NetworkGetPlayerFlags(uint32_t index) -{ - return 0; -} -int32_t NetworkGetPlayerPing(uint32_t index) -{ - return 0; -} -int32_t NetworkGetPlayerID(uint32_t index) -{ - return 0; -} -money64 NetworkGetPlayerMoneySpent(uint32_t index) -{ - return 0.00_GBP; -} -std::string NetworkGetPlayerIPAddress(uint32_t id) -{ - return {}; -} -std::string NetworkGetPlayerPublicKeyHash(uint32_t id) -{ - return {}; -} -void NetworkIncrementPlayerNumCommands(uint32_t playerIndex) -{ -} -void NetworkAddPlayerMoneySpent(uint32_t index, money64 cost) -{ -} -int32_t NetworkGetPlayerLastAction(uint32_t index, int32_t time) -{ - return -999; -} -void NetworkSetPlayerLastAction(uint32_t index, GameCommand command) -{ -} -CoordsXYZ NetworkGetPlayerLastActionCoord(uint32_t index) -{ - return { 0, 0, 0 }; -} -void NetworkSetPlayerLastActionCoord(uint32_t index, const CoordsXYZ& coord) -{ -} -uint32_t NetworkGetPlayerCommandsRan(uint32_t index) -{ - return 0; -} -int32_t NetworkGetPlayerIndex(uint32_t id) -{ - return -1; -} -uint8_t NetworkGetPlayerGroup(uint32_t index) -{ - return 0; -} -void NetworkSetPlayerGroup(uint32_t index, uint32_t groupindex) -{ -} -int32_t NetworkGetGroupIndex(uint8_t id) -{ - return -1; -} -uint8_t NetworkGetGroupID(uint32_t index) -{ - return 0; -} -int32_t NetworkGetNumGroups() -{ - return 0; -} -const char* NetworkGetGroupName(uint32_t index) -{ - return ""; -}; - -GameActions::Result NetworkSetPlayerGroup( - NetworkPlayerId_t actionPlayerId, NetworkPlayerId_t playerId, uint8_t groupId, bool isExecuting) -{ - return GameActions::Result(); -} -GameActions::Result NetworkModifyGroups( - NetworkPlayerId_t actionPlayerId, GameActions::ModifyGroupType type, uint8_t groupId, const std::string& name, - uint32_t permissionIndex, GameActions::PermissionState permissionState, bool isExecuting) -{ - return GameActions::Result(); -} -GameActions::Result NetworkKickPlayer(NetworkPlayerId_t playerId, bool isExecuting) -{ - return GameActions::Result(); -} -uint8_t NetworkGetDefaultGroup() -{ - return 0; -} -int32_t NetworkGetNumActions() -{ - return 0; -} -StringId NetworkGetActionNameStringID(uint32_t index) -{ - return -1; -} -int32_t NetworkCanPerformAction(uint32_t groupindex, NetworkPermission index) -{ - return 0; -} -int32_t NetworkCanPerformCommand(uint32_t groupindex, int32_t index) -{ - return 0; -} -void NetworkSetPickupPeep(uint8_t playerid, Peep* peep) -{ - _pickup_peep = peep; -} -Peep* NetworkGetPickupPeep(uint8_t playerid) -{ - return _pickup_peep; -} -void NetworkSetPickupPeepOldX(uint8_t playerid, int32_t x) -{ - _pickup_peep_old_x = x; -} -int32_t NetworkGetPickupPeepOldX(uint8_t playerid) -{ - return _pickup_peep_old_x; -} -void NetworkSendChat(const char* text, const std::vector& playerIds) -{ -} -void NetworkSendPassword(const std::string& password) -{ -} -void NetworkReconnect() -{ -} -void NetworkShutdownClient() -{ -} -void NetworkSetPassword(const char* password) -{ -} -uint8_t NetworkGetCurrentPlayerId() -{ - return 0; -} -int32_t NetworkGetCurrentPlayerGroupIndex() -{ - return 0; -} -bool NetworkIsServerPlayerInvisible() -{ - return false; -} -void NetworkAppendChatLog(std::string_view) -{ -} -void NetworkAppendServerLog(const utf8* text) -{ -} -u8string NetworkGetServerName() -{ - return u8string(); -} -u8string NetworkGetServerDescription() -{ - return u8string(); -} -u8string NetworkGetServerGreeting() -{ - return u8string(); -} -u8string NetworkGetServerProviderName() -{ - return u8string(); -} -u8string NetworkGetServerProviderEmail() -{ - return u8string(); -} -u8string NetworkGetServerProviderWebsite() -{ - return u8string(); -} -std::string NetworkGetVersion() -{ - return "Multiplayer disabled"; -} -NetworkStats NetworkGetStats() -{ - return NetworkStats{}; -} -NetworkServerState NetworkGetServerState() -{ - return NetworkServerState{}; -} -json_t NetworkGetServerInfoAsJson() -{ - return {}; -} #endif /* DISABLE_NETWORK */ diff --git a/src/openrct2/network/NetworkBase.h b/src/openrct2/network/NetworkBase.h index 625bb978a2..5ac84e2b7a 100644 --- a/src/openrct2/network/NetworkBase.h +++ b/src/openrct2/network/NetworkBase.h @@ -22,235 +22,239 @@ namespace OpenRCT2 struct IContext; } -class NetworkBase : public OpenRCT2::System +namespace OpenRCT2::Network { -public: - NetworkBase(OpenRCT2::IContext& context); - -public: // Uncategorized - bool BeginServer(uint16_t port, const std::string& address); - bool BeginClient(const std::string& host, uint16_t port); - -public: // Common - bool Init(); - void Close(); - uint32_t GetServerTick() const noexcept; - // FIXME: This is currently the wrong function to override in System, will be refactored later. - void Update() override final; - void Flush(); - void ProcessPending(); - void ProcessPlayerList(); - auto GetPlayerIteratorByID(uint8_t id) const; - auto GetGroupIteratorByID(uint8_t id) const; - NetworkPlayer* GetPlayerByID(uint8_t id) const; - NetworkGroup* GetGroupByID(uint8_t id) const; - int32_t GetTotalNumPlayers() const noexcept; - int32_t GetNumVisiblePlayers() const noexcept; - void SetPassword(u8string_view password); - uint8_t GetDefaultGroup() const noexcept; - std::string BeginLog(const std::string& directory, const std::string& midName, const std::string& filenameFormat); - void AppendLog(std::ostream& fs, std::string_view s); - void BeginChatLog(); - void AppendChatLog(std::string_view s); - void CloseChatLog(); - NetworkStats GetStats() const; - json_t GetServerInfoAsJson() const; - bool ProcessConnection(NetworkConnection& connection); - void CloseConnection(); - NetworkPlayer* AddPlayer(const std::string& name, const std::string& keyhash); - void ProcessPacket(NetworkConnection& connection, NetworkPacket& packet); - -public: // Server - NetworkConnection* GetPlayerConnection(uint8_t id) const; - void KickPlayer(int32_t playerId); - NetworkGroup* AddGroup(); - void LoadGroups(); - void SetDefaultGroup(uint8_t id); - void SaveGroups(); - void RemoveGroup(uint8_t id); - uint8_t GetGroupIDByHash(const std::string& keyhash); - void BeginServerLog(); - void AppendServerLog(const std::string& s); - void CloseServerLog(); - void DecayCooldown(NetworkPlayer* player); - void AddClient(std::unique_ptr&& socket); - std::string GetMasterServerUrl(); - std::string GenerateAdvertiseKey(); - void SetupDefaultGroups(); - void RemovePlayer(std::unique_ptr& connection); - void UpdateServer(); - void ServerClientDisconnected(std::unique_ptr& connection); - bool SaveMap(OpenRCT2::IStream* stream, const std::vector& objects) const; - std::vector SaveForNetwork(const std::vector& objects) const; - std::string MakePlayerNameUnique(const std::string& name); - - // Packet dispatchers. - void ServerSendAuth(NetworkConnection& connection); - void ServerSendToken(NetworkConnection& connection); - void ServerSendMap(NetworkConnection* connection = nullptr); - void ServerSendChat(const char* text, const std::vector& playerIds = {}); - void ServerSendGameAction(const OpenRCT2::GameActions::GameAction* action); - void ServerSendTick(); - void ServerSendPlayerInfo(int32_t playerId); - void ServerSendPlayerList(); - void ServerSendPing(); - void ServerSendPingList(); - void ServerSendSetDisconnectMsg(NetworkConnection& connection, const char* msg); - void ServerSendGameInfo(NetworkConnection& connection); - void ServerSendShowError(NetworkConnection& connection, StringId title, StringId message); - void ServerSendGroupList(NetworkConnection& connection); - void ServerSendEventPlayerJoined(const char* playerName); - void ServerSendEventPlayerDisconnected(const char* playerName, const char* reason); - void ServerSendObjectsList( - NetworkConnection& connection, const std::vector& objects) const; - void ServerSendScripts(NetworkConnection& connection); - - // Handlers - void ServerHandleRequestGamestate(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandleHeartbeat(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandleAuth(NetworkConnection& connection, NetworkPacket& packet); - void ServerClientJoined(std::string_view name, const std::string& keyhash, NetworkConnection& connection); - void ServerHandleChat(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandleGameAction(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandlePing(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandleGameInfo(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandleToken(NetworkConnection& connection, NetworkPacket& packet); - void ServerHandleMapRequest(NetworkConnection& connection, NetworkPacket& packet); - -public: // Client - void Reconnect(); - int32_t GetMode() const noexcept; - NetworkAuth GetAuthStatus(); - int32_t GetStatus() const noexcept; - uint8_t GetPlayerID() const noexcept; - void ProcessPlayerInfo(); - void ProcessDisconnectedClients(); - static const char* FormatChat(NetworkPlayer* fromplayer, const char* text); - void SendPacketToClients(const NetworkPacket& packet, bool front = false, bool gameCmd = false) const; - bool CheckSRAND(uint32_t tick, uint32_t srand0); - bool CheckDesynchronizaton(); - void RequestStateSnapshot(); - bool IsDesynchronised() const noexcept; - NetworkServerState GetServerState() const noexcept; - void ServerClientDisconnected(); - bool LoadMap(OpenRCT2::IStream* stream); - void UpdateClient(); - - // Packet dispatchers. - void Client_Send_RequestGameState(uint32_t tick); - void Client_Send_TOKEN(); - void Client_Send_AUTH( - const std::string& name, const std::string& password, const std::string& pubkey, const std::vector& signature); - void Client_Send_CHAT(const char* text); - void Client_Send_GAME_ACTION(const OpenRCT2::GameActions::GameAction* action); - void Client_Send_PING(); - void Client_Send_GAMEINFO(); - void Client_Send_MAPREQUEST(const std::vector& objects); - void Client_Send_HEARTBEAT(NetworkConnection& connection) const; - - // Handlers. - void Client_Handle_AUTH(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_MAP(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_CHAT(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_GAME_ACTION(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_TICK(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_PLAYERINFO(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_PLAYERLIST(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_PING(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_PINGLIST(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_SETDISCONNECTMSG(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_GAMEINFO(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_SHOWERROR(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_GROUPLIST(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_EVENT(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_OBJECTS_LIST(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_SCRIPTS_HEADER(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_SCRIPTS_DATA(NetworkConnection& connection, NetworkPacket& packet); - void Client_Handle_GAMESTATE(NetworkConnection& connection, NetworkPacket& packet); - - std::vector _challenge; - std::map _gameActionCallbacks; - NetworkKey _key; - NetworkUserManager _userManager; - -public: // Public common - std::string ServerName; - std::string ServerDescription; - std::string ServerGreeting; - std::string ServerProviderName; - std::string ServerProviderEmail; - std::string ServerProviderWebsite; - std::vector> player_list; - std::vector> group_list; - bool IsServerPlayerInvisible = false; - -private: // Common Data - using CommandHandler = void (NetworkBase::*)(NetworkConnection& connection, NetworkPacket& packet); - - std::vector chunk_buffer; - std::ofstream _chat_log_fs; - uint32_t _lastUpdateTime = 0; - uint32_t _currentDeltaTime = 0; - int32_t mode = NETWORK_MODE_NONE; - uint8_t default_group = 0; - bool _closeLock = false; - bool _requireClose = false; - -private: // Server Data - std::unordered_map server_command_handlers; - std::unique_ptr _listenSocket; - std::unique_ptr _advertiser; - std::list> client_connection_list; - std::string _serverLogPath; - std::string _serverLogFilenameFormat = "%Y%m%d-%H%M%S.txt"; - std::ofstream _server_log_fs; - uint16_t listening_port = 0; - bool _playerListInvalidated = false; - -private: // Client Data - struct PlayerListUpdate + class NetworkBase : public OpenRCT2::System { - std::vector players; - }; + public: + NetworkBase(OpenRCT2::IContext& context); - struct ServerTickData - { - uint32_t srand0; - uint32_t tick; - std::string spriteHash; - }; + public: // Uncategorized + bool BeginServer(uint16_t port, const std::string& address); + bool BeginClient(const std::string& host, uint16_t port); - struct ServerScriptsData - { - uint32_t pluginCount{}; - uint32_t dataSize{}; - OpenRCT2::MemoryStream data; - }; + public: // Common + bool Init(); + void Close(); + uint32_t GetServerTick() const noexcept; + // FIXME: This is currently the wrong function to override in System, will be refactored later. + void Update() override final; + void Flush(); + void ProcessPending(); + void ProcessPlayerList(); + auto GetPlayerIteratorByID(uint8_t id) const; + auto GetGroupIteratorByID(uint8_t id) const; + NetworkPlayer* GetPlayerByID(uint8_t id) const; + NetworkGroup* GetGroupByID(uint8_t id) const; + int32_t GetTotalNumPlayers() const noexcept; + int32_t GetNumVisiblePlayers() const noexcept; + void SetPassword(u8string_view password); + uint8_t GetDefaultGroup() const noexcept; + std::string BeginLog(const std::string& directory, const std::string& midName, const std::string& filenameFormat); + void AppendLog(std::ostream& fs, std::string_view s); + void BeginChatLog(); + void AppendChatLog(std::string_view s); + void CloseChatLog(); + NetworkStats GetStats() const; + json_t GetServerInfoAsJson() const; + bool ProcessConnection(NetworkConnection& connection); + void CloseConnection(); + NetworkPlayer* AddPlayer(const std::string& name, const std::string& keyhash); + void ProcessPacket(NetworkConnection& connection, NetworkPacket& packet); - std::unordered_map client_command_handlers; - std::unique_ptr _serverConnection; - std::map _pendingPlayerLists; - std::multimap _pendingPlayerInfo; - std::map _serverTickData; - std::vector _missingObjects; - std::string _host; - std::string _chatLogPath; - std::string _chatLogFilenameFormat = "%Y%m%d-%H%M%S.txt"; - std::string _password; - OpenRCT2::MemoryStream _serverGameState; - NetworkServerState _serverState; - uint32_t _lastSentHeartbeat = 0; - uint32_t last_ping_sent_time = 0; - uint32_t server_connect_time = 0; - uint32_t _actionId; - int32_t status = NETWORK_STATUS_NONE; - uint8_t player_id = 0; - uint16_t _port = 0; - SocketStatus _lastConnectStatus = SocketStatus::Closed; - bool _requireReconnect = false; - bool _clientMapLoaded = false; - ServerScriptsData _serverScriptsData{}; -}; + public: // Server + NetworkConnection* GetPlayerConnection(uint8_t id) const; + void KickPlayer(int32_t playerId); + NetworkGroup* AddGroup(); + void LoadGroups(); + void SetDefaultGroup(uint8_t id); + void SaveGroups(); + void RemoveGroup(uint8_t id); + uint8_t GetGroupIDByHash(const std::string& keyhash); + void BeginServerLog(); + void AppendServerLog(const std::string& s); + void CloseServerLog(); + void DecayCooldown(NetworkPlayer* player); + void AddClient(std::unique_ptr&& socket); + std::string GetMasterServerUrl(); + std::string GenerateAdvertiseKey(); + void SetupDefaultGroups(); + void RemovePlayer(std::unique_ptr& connection); + void UpdateServer(); + void ServerClientDisconnected(std::unique_ptr& connection); + bool SaveMap(OpenRCT2::IStream* stream, const std::vector& objects) const; + std::vector SaveForNetwork(const std::vector& objects) const; + std::string MakePlayerNameUnique(const std::string& name); + + // Packet dispatchers. + void ServerSendAuth(NetworkConnection& connection); + void ServerSendToken(NetworkConnection& connection); + void ServerSendMap(NetworkConnection* connection = nullptr); + void ServerSendChat(const char* text, const std::vector& playerIds = {}); + void ServerSendGameAction(const OpenRCT2::GameActions::GameAction* action); + void ServerSendTick(); + void ServerSendPlayerInfo(int32_t playerId); + void ServerSendPlayerList(); + void ServerSendPing(); + void ServerSendPingList(); + void ServerSendSetDisconnectMsg(NetworkConnection& connection, const char* msg); + void ServerSendGameInfo(NetworkConnection& connection); + void ServerSendShowError(NetworkConnection& connection, StringId title, StringId message); + void ServerSendGroupList(NetworkConnection& connection); + void ServerSendEventPlayerJoined(const char* playerName); + void ServerSendEventPlayerDisconnected(const char* playerName, const char* reason); + void ServerSendObjectsList( + NetworkConnection& connection, const std::vector& objects) const; + void ServerSendScripts(NetworkConnection& connection); + + // Handlers + void ServerHandleRequestGamestate(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandleHeartbeat(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandleAuth(NetworkConnection& connection, NetworkPacket& packet); + void ServerClientJoined(std::string_view name, const std::string& keyhash, NetworkConnection& connection); + void ServerHandleChat(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandleGameAction(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandlePing(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandleGameInfo(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandleToken(NetworkConnection& connection, NetworkPacket& packet); + void ServerHandleMapRequest(NetworkConnection& connection, NetworkPacket& packet); + + public: // Client + void Reconnect(); + int32_t GetMode() const noexcept; + NetworkAuth GetAuthStatus(); + int32_t GetStatus() const noexcept; + uint8_t GetPlayerID() const noexcept; + void ProcessPlayerInfo(); + void ProcessDisconnectedClients(); + static const char* FormatChat(NetworkPlayer* fromplayer, const char* text); + void SendPacketToClients(const NetworkPacket& packet, bool front = false, bool gameCmd = false) const; + bool CheckSRAND(uint32_t tick, uint32_t srand0); + bool CheckDesynchronizaton(); + void RequestStateSnapshot(); + bool IsDesynchronised() const noexcept; + NetworkServerState GetServerState() const noexcept; + void ServerClientDisconnected(); + bool LoadMap(OpenRCT2::IStream* stream); + void UpdateClient(); + + // Packet dispatchers. + void Client_Send_RequestGameState(uint32_t tick); + void Client_Send_TOKEN(); + void Client_Send_AUTH( + const std::string& name, const std::string& password, const std::string& pubkey, + const std::vector& signature); + void Client_Send_CHAT(const char* text); + void Client_Send_GAME_ACTION(const OpenRCT2::GameActions::GameAction* action); + void Client_Send_PING(); + void Client_Send_GAMEINFO(); + void Client_Send_MAPREQUEST(const std::vector& objects); + void Client_Send_HEARTBEAT(NetworkConnection& connection) const; + + // Handlers. + void Client_Handle_AUTH(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_MAP(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_CHAT(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_GAME_ACTION(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_TICK(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_PLAYERINFO(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_PLAYERLIST(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_PING(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_PINGLIST(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_SETDISCONNECTMSG(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_GAMEINFO(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_SHOWERROR(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_GROUPLIST(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_EVENT(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_OBJECTS_LIST(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_SCRIPTS_HEADER(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_SCRIPTS_DATA(NetworkConnection& connection, NetworkPacket& packet); + void Client_Handle_GAMESTATE(NetworkConnection& connection, NetworkPacket& packet); + + std::vector _challenge; + std::map _gameActionCallbacks; + NetworkKey _key; + NetworkUserManager _userManager; + + public: // Public common + std::string ServerName; + std::string ServerDescription; + std::string ServerGreeting; + std::string ServerProviderName; + std::string ServerProviderEmail; + std::string ServerProviderWebsite; + std::vector> player_list; + std::vector> group_list; + bool IsServerPlayerInvisible = false; + + private: // Common Data + using CommandHandler = void (NetworkBase::*)(NetworkConnection& connection, NetworkPacket& packet); + + std::vector chunk_buffer; + std::ofstream _chat_log_fs; + uint32_t _lastUpdateTime = 0; + uint32_t _currentDeltaTime = 0; + int32_t mode = NETWORK_MODE_NONE; + uint8_t default_group = 0; + bool _closeLock = false; + bool _requireClose = false; + + private: // Server Data + std::unordered_map server_command_handlers; + std::unique_ptr _listenSocket; + std::unique_ptr _advertiser; + std::list> client_connection_list; + std::string _serverLogPath; + std::string _serverLogFilenameFormat = "%Y%m%d-%H%M%S.txt"; + std::ofstream _server_log_fs; + uint16_t listening_port = 0; + bool _playerListInvalidated = false; + + private: // Client Data + struct PlayerListUpdate + { + std::vector players; + }; + + struct ServerTickData + { + uint32_t srand0; + uint32_t tick; + std::string spriteHash; + }; + + struct ServerScriptsData + { + uint32_t pluginCount{}; + uint32_t dataSize{}; + OpenRCT2::MemoryStream data; + }; + + std::unordered_map client_command_handlers; + std::unique_ptr _serverConnection; + std::map _pendingPlayerLists; + std::multimap _pendingPlayerInfo; + std::map _serverTickData; + std::vector _missingObjects; + std::string _host; + std::string _chatLogPath; + std::string _chatLogFilenameFormat = "%Y%m%d-%H%M%S.txt"; + std::string _password; + OpenRCT2::MemoryStream _serverGameState; + NetworkServerState _serverState; + uint32_t _lastSentHeartbeat = 0; + uint32_t last_ping_sent_time = 0; + uint32_t server_connect_time = 0; + uint32_t _actionId; + int32_t status = NETWORK_STATUS_NONE; + uint8_t player_id = 0; + uint16_t _port = 0; + SocketStatus _lastConnectStatus = SocketStatus::Closed; + bool _requireReconnect = false; + bool _clientMapLoaded = false; + ServerScriptsData _serverScriptsData{}; + }; +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkClient.h b/src/openrct2/network/NetworkClient.h index 34b3fd3d5e..ef49fb2d6f 100644 --- a/src/openrct2/network/NetworkClient.h +++ b/src/openrct2/network/NetworkClient.h @@ -4,9 +4,12 @@ #ifndef DISABLE_NETWORK -class NetworkClient final : public NetworkBase +namespace OpenRCT2::Network { -public: -}; + class NetworkClient final : public NetworkBase + { + public: + }; +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkConnection.cpp b/src/openrct2/network/NetworkConnection.cpp index 1adc1689f5..a070678236 100644 --- a/src/openrct2/network/NetworkConnection.cpp +++ b/src/openrct2/network/NetworkConnection.cpp @@ -19,217 +19,218 @@ #include -using namespace OpenRCT2; - -static constexpr size_t kNetworkDisconnectReasonBufSize = 256; -static constexpr size_t kNetworkBufferSize = (1024 * 64) - 1; // 64 KiB, maximum packet size. +namespace OpenRCT2::Network +{ + static constexpr size_t kNetworkDisconnectReasonBufSize = 256; + static constexpr size_t kNetworkBufferSize = (1024 * 64) - 1; // 64 KiB, maximum packet size. #ifndef DEBUG -static constexpr size_t kNetworkNoDataTimeout = 20; // Seconds. + static constexpr size_t kNetworkNoDataTimeout = 20; // Seconds. #endif -static_assert(kNetworkBufferSize <= std::numeric_limits::max(), "kNetworkBufferSize too big, uint16_t is max."); + static_assert(kNetworkBufferSize <= std::numeric_limits::max(), "kNetworkBufferSize too big, uint16_t is max."); -NetworkConnection::NetworkConnection() noexcept -{ - ResetLastPacketTime(); -} - -NetworkReadPacket NetworkConnection::ReadPacket() -{ - size_t bytesRead = 0; - - // Read packet header. - auto& header = InboundPacket.Header; - if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) + NetworkConnection::NetworkConnection() noexcept { - const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred; - - uint8_t* buffer = reinterpret_cast(&InboundPacket.Header); - - NetworkReadPacket status = Socket->ReceiveData(buffer, missingLength, &bytesRead); - if (status != NetworkReadPacket::Success) - { - return status; - } - - InboundPacket.BytesTransferred += bytesRead; - if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) - { - // If still not enough data for header, keep waiting. - return NetworkReadPacket::MoreData; - } - - // Normalise values. - header.Size = Convert::NetworkToHost(header.Size); - header.Id = ByteSwapBE(header.Id); - - // NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size. - // Previously the Id field was not part of the header rather part of the body. - header.Size -= std::min(header.Size, sizeof(header.Id)); - - // Fall-through: Read rest of packet. + ResetLastPacketTime(); } - // Read packet body. + NetworkReadPacket NetworkConnection::ReadPacket() { - // NOTE: BytesTransfered includes the header length, this will not underflow. - const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header)); + size_t bytesRead = 0; - uint8_t buffer[kNetworkBufferSize]; - - if (missingLength > 0) + // Read packet header. + auto& header = InboundPacket.Header; + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) { - NetworkReadPacket status = Socket->ReceiveData(buffer, std::min(missingLength, kNetworkBufferSize), &bytesRead); + const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred; + + uint8_t* buffer = reinterpret_cast(&InboundPacket.Header); + + NetworkReadPacket status = Socket->ReceiveData(buffer, missingLength, &bytesRead); if (status != NetworkReadPacket::Success) { return status; } InboundPacket.BytesTransferred += bytesRead; - InboundPacket.Write(buffer, bytesRead); + if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header)) + { + // If still not enough data for header, keep waiting. + return NetworkReadPacket::MoreData; + } + + // Normalise values. + header.Size = Convert::NetworkToHost(header.Size); + header.Id = ByteSwapBE(header.Id); + + // NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size. + // Previously the Id field was not part of the header rather part of the body. + header.Size -= std::min(header.Size, sizeof(header.Id)); + + // Fall-through: Read rest of packet. } - if (InboundPacket.Data.size() == header.Size) + // Read packet body. { - // Received complete packet. - _lastPacketTime = Platform::GetTicks(); + // NOTE: BytesTransfered includes the header length, this will not underflow. + const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header)); - RecordPacketStats(InboundPacket, false); + uint8_t buffer[kNetworkBufferSize]; - return NetworkReadPacket::Success; + if (missingLength > 0) + { + NetworkReadPacket status = Socket->ReceiveData(buffer, std::min(missingLength, kNetworkBufferSize), &bytesRead); + if (status != NetworkReadPacket::Success) + { + return status; + } + + InboundPacket.BytesTransferred += bytesRead; + InboundPacket.Write(buffer, bytesRead); + } + + if (InboundPacket.Data.size() == header.Size) + { + // Received complete packet. + _lastPacketTime = Platform::GetTicks(); + + RecordPacketStats(InboundPacket, false); + + return NetworkReadPacket::Success; + } + } + + return NetworkReadPacket::MoreData; + } + + static sfl::small_vector serializePacket(const NetworkPacket& packet) + { + // NOTE: For compatibility reasons for the master server we need to add sizeof(Header.Id) to the size. + // Previously the Id field was not part of the header rather part of the body. + const auto bodyLength = packet.Data.size() + sizeof(packet.Header.Id); + + Guard::Assert(bodyLength <= std::numeric_limits::max(), "Packet size too large"); + + auto header = packet.Header; + header.Size = static_cast(bodyLength); + header.Size = Convert::HostToNetwork(header.Size); + header.Id = ByteSwapBE(header.Id); + + sfl::small_vector buffer; + buffer.reserve(sizeof(header) + packet.Data.size()); + + buffer.insert(buffer.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); + buffer.insert(buffer.end(), packet.Data.begin(), packet.Data.end()); + + return buffer; + } + + void NetworkConnection::QueuePacket(const NetworkPacket& packet, bool front) + { + if (AuthStatus == NetworkAuth::Ok || !packet.CommandRequiresAuth()) + { + const auto payload = serializePacket(packet); + if (front) + { + _outboundBuffer.insert(_outboundBuffer.begin(), payload.begin(), payload.end()); + } + else + { + _outboundBuffer.insert(_outboundBuffer.end(), payload.begin(), payload.end()); + } + + RecordPacketStats(packet, true); } } - return NetworkReadPacket::MoreData; -} - -static sfl::small_vector serializePacket(const NetworkPacket& packet) -{ - // NOTE: For compatibility reasons for the master server we need to add sizeof(Header.Id) to the size. - // Previously the Id field was not part of the header rather part of the body. - const auto bodyLength = packet.Data.size() + sizeof(packet.Header.Id); - - Guard::Assert(bodyLength <= std::numeric_limits::max(), "Packet size too large"); - - auto header = packet.Header; - header.Size = static_cast(bodyLength); - header.Size = Convert::HostToNetwork(header.Size); - header.Id = ByteSwapBE(header.Id); - - sfl::small_vector buffer; - buffer.reserve(sizeof(header) + packet.Data.size()); - - buffer.insert(buffer.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); - buffer.insert(buffer.end(), packet.Data.begin(), packet.Data.end()); - - return buffer; -} - -void NetworkConnection::QueuePacket(const NetworkPacket& packet, bool front) -{ - if (AuthStatus == NetworkAuth::Ok || !packet.CommandRequiresAuth()) + void NetworkConnection::Disconnect() noexcept { - const auto payload = serializePacket(packet); - if (front) + ShouldDisconnect = true; + } + + bool NetworkConnection::IsValid() const + { + return !ShouldDisconnect && Socket->GetStatus() == SocketStatus::Connected; + } + + void NetworkConnection::SendQueuedData() + { + if (_outboundBuffer.empty()) { - _outboundBuffer.insert(_outboundBuffer.begin(), payload.begin(), payload.end()); + return; + } + + const auto bytesSent = Socket->SendData(_outboundBuffer.data(), _outboundBuffer.size()); + + if (bytesSent > 0) + { + _outboundBuffer.erase(_outboundBuffer.begin(), _outboundBuffer.begin() + bytesSent); + } + } + + void NetworkConnection::ResetLastPacketTime() noexcept + { + _lastPacketTime = Platform::GetTicks(); + } + + bool NetworkConnection::ReceivedPacketRecently() const noexcept + { + #ifndef DEBUG + constexpr auto kTimeoutMs = kNetworkNoDataTimeout * 1000; + if (Platform::GetTicks() > _lastPacketTime + kTimeoutMs) + { + return false; + } + #endif + return true; + } + + const utf8* NetworkConnection::GetLastDisconnectReason() const noexcept + { + return this->_lastDisconnectReason.c_str(); + } + + void NetworkConnection::SetLastDisconnectReason(std::string_view src) + { + _lastDisconnectReason = src; + } + + void NetworkConnection::SetLastDisconnectReason(const StringId string_id, void* args) + { + char buffer[kNetworkDisconnectReasonBufSize]; + OpenRCT2::FormatStringLegacy(buffer, kNetworkDisconnectReasonBufSize, string_id, args); + SetLastDisconnectReason(buffer); + } + + void NetworkConnection::RecordPacketStats(const NetworkPacket& packet, bool sending) + { + uint32_t packetSize = static_cast(packet.BytesTransferred); + NetworkStatisticsGroup trafficGroup; + + switch (packet.GetCommand()) + { + case NetworkCommand::GameAction: + trafficGroup = NetworkStatisticsGroup::Commands; + break; + case NetworkCommand::Map: + trafficGroup = NetworkStatisticsGroup::MapData; + break; + default: + trafficGroup = NetworkStatisticsGroup::Base; + break; + } + + if (sending) + { + Stats.bytesSent[EnumValue(trafficGroup)] += packetSize; + Stats.bytesSent[EnumValue(NetworkStatisticsGroup::Total)] += packetSize; } else { - _outboundBuffer.insert(_outboundBuffer.end(), payload.begin(), payload.end()); + Stats.bytesReceived[EnumValue(trafficGroup)] += packetSize; + Stats.bytesReceived[EnumValue(NetworkStatisticsGroup::Total)] += packetSize; } - - RecordPacketStats(packet, true); } -} - -void NetworkConnection::Disconnect() noexcept -{ - ShouldDisconnect = true; -} - -bool NetworkConnection::IsValid() const -{ - return !ShouldDisconnect && Socket->GetStatus() == SocketStatus::Connected; -} - -void NetworkConnection::SendQueuedData() -{ - if (_outboundBuffer.empty()) - { - return; - } - - const auto bytesSent = Socket->SendData(_outboundBuffer.data(), _outboundBuffer.size()); - - if (bytesSent > 0) - { - _outboundBuffer.erase(_outboundBuffer.begin(), _outboundBuffer.begin() + bytesSent); - } -} - -void NetworkConnection::ResetLastPacketTime() noexcept -{ - _lastPacketTime = Platform::GetTicks(); -} - -bool NetworkConnection::ReceivedPacketRecently() const noexcept -{ - #ifndef DEBUG - constexpr auto kTimeoutMs = kNetworkNoDataTimeout * 1000; - if (Platform::GetTicks() > _lastPacketTime + kTimeoutMs) - { - return false; - } - #endif - return true; -} - -const utf8* NetworkConnection::GetLastDisconnectReason() const noexcept -{ - return this->_lastDisconnectReason.c_str(); -} - -void NetworkConnection::SetLastDisconnectReason(std::string_view src) -{ - _lastDisconnectReason = src; -} - -void NetworkConnection::SetLastDisconnectReason(const StringId string_id, void* args) -{ - char buffer[kNetworkDisconnectReasonBufSize]; - OpenRCT2::FormatStringLegacy(buffer, kNetworkDisconnectReasonBufSize, string_id, args); - SetLastDisconnectReason(buffer); -} - -void NetworkConnection::RecordPacketStats(const NetworkPacket& packet, bool sending) -{ - uint32_t packetSize = static_cast(packet.BytesTransferred); - NetworkStatisticsGroup trafficGroup; - - switch (packet.GetCommand()) - { - case NetworkCommand::GameAction: - trafficGroup = NetworkStatisticsGroup::Commands; - break; - case NetworkCommand::Map: - trafficGroup = NetworkStatisticsGroup::MapData; - break; - default: - trafficGroup = NetworkStatisticsGroup::Base; - break; - } - - if (sending) - { - Stats.bytesSent[EnumValue(trafficGroup)] += packetSize; - Stats.bytesSent[EnumValue(NetworkStatisticsGroup::Total)] += packetSize; - } - else - { - Stats.bytesReceived[EnumValue(trafficGroup)] += packetSize; - Stats.bytesReceived[EnumValue(NetworkStatisticsGroup::Total)] += packetSize; - } -} +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/NetworkConnection.h b/src/openrct2/network/NetworkConnection.h index 6850fd6012..874f39eb93 100644 --- a/src/openrct2/network/NetworkConnection.h +++ b/src/openrct2/network/NetworkConnection.h @@ -20,51 +20,54 @@ #include #include -class NetworkPlayer; - namespace OpenRCT2 { struct ObjectRepositoryItem; } -class NetworkConnection final +namespace OpenRCT2::Network { -public: - std::unique_ptr Socket = nullptr; - NetworkPacket InboundPacket; - NetworkAuth AuthStatus = NetworkAuth::None; - NetworkStats Stats = {}; - NetworkPlayer* Player = nullptr; - uint32_t PingTime = 0; - NetworkKey Key; - std::vector Challenge; - std::vector RequestedObjects; - bool ShouldDisconnect = false; + class NetworkPlayer; - NetworkConnection() noexcept; + class NetworkConnection final + { + public: + std::unique_ptr Socket = nullptr; + NetworkPacket InboundPacket; + NetworkAuth AuthStatus = NetworkAuth::None; + NetworkStats Stats = {}; + NetworkPlayer* Player = nullptr; + uint32_t PingTime = 0; + NetworkKey Key; + std::vector Challenge; + std::vector RequestedObjects; + bool ShouldDisconnect = false; - NetworkReadPacket ReadPacket(); - void QueuePacket(const NetworkPacket& packet, bool front = false); + NetworkConnection() noexcept; - // This will not immediately disconnect the client. The disconnect - // will happen post-tick. - void Disconnect() noexcept; + NetworkReadPacket ReadPacket(); + void QueuePacket(const NetworkPacket& packet, bool front = false); - bool IsValid() const; - void SendQueuedData(); - void ResetLastPacketTime() noexcept; - bool ReceivedPacketRecently() const noexcept; + // This will not immediately disconnect the client. The disconnect + // will happen post-tick. + void Disconnect() noexcept; - const utf8* GetLastDisconnectReason() const noexcept; - void SetLastDisconnectReason(std::string_view src); - void SetLastDisconnectReason(const StringId string_id, void* args = nullptr); + bool IsValid() const; + void SendQueuedData(); + void ResetLastPacketTime() noexcept; + bool ReceivedPacketRecently() const noexcept; -private: - std::vector _outboundBuffer; - uint32_t _lastPacketTime = 0; - std::string _lastDisconnectReason; + const utf8* GetLastDisconnectReason() const noexcept; + void SetLastDisconnectReason(std::string_view src); + void SetLastDisconnectReason(const StringId string_id, void* args = nullptr); - void RecordPacketStats(const NetworkPacket& packet, bool sending); -}; + private: + std::vector _outboundBuffer; + uint32_t _lastPacketTime = 0; + std::string _lastDisconnectReason; + + void RecordPacketStats(const NetworkPacket& packet, bool sending); + }; +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkGroup.cpp b/src/openrct2/network/NetworkGroup.cpp index e83d93e6b7..7dfd14884d 100644 --- a/src/openrct2/network/NetworkGroup.cpp +++ b/src/openrct2/network/NetworkGroup.cpp @@ -15,119 +15,120 @@ #include "NetworkAction.h" #include "NetworkTypes.h" -using namespace OpenRCT2; - -NetworkGroup NetworkGroup::FromJson(const json_t& jsonData) +namespace OpenRCT2::Network { - Guard::Assert(jsonData.is_object(), "NetworkGroup::FromJson expects parameter jsonData to be object"); - - NetworkGroup group; - json_t jsonId = jsonData["id"]; - json_t jsonName = jsonData["name"]; - json_t jsonPermissions = jsonData["permissions"]; - - if (jsonId.is_null() || jsonName.is_null() || jsonPermissions.is_null()) + NetworkGroup NetworkGroup::FromJson(const json_t& jsonData) { - throw std::runtime_error("Missing group data"); + Guard::Assert(jsonData.is_object(), "NetworkGroup::FromJson expects parameter jsonData to be object"); + + NetworkGroup group; + json_t jsonId = jsonData["id"]; + json_t jsonName = jsonData["name"]; + json_t jsonPermissions = jsonData["permissions"]; + + if (jsonId.is_null() || jsonName.is_null() || jsonPermissions.is_null()) + { + throw std::runtime_error("Missing group data"); + } + + group.Id = Json::GetNumber(jsonId); + group._name = Json::GetString(jsonName); + std::fill(group.ActionsAllowed.begin(), group.ActionsAllowed.end(), 0); + + for (const auto& jsonValue : jsonPermissions) + { + const std::string permission = Json::GetString(jsonValue); + + NetworkPermission action_id = NetworkActions::FindCommandByPermissionName(permission); + if (action_id != NetworkPermission::Count) + { + group.ToggleActionPermission(action_id); + } + } + return group; } - group.Id = Json::GetNumber(jsonId); - group._name = Json::GetString(jsonName); - std::fill(group.ActionsAllowed.begin(), group.ActionsAllowed.end(), 0); - - for (const auto& jsonValue : jsonPermissions) + json_t NetworkGroup::ToJson() const { - const std::string permission = Json::GetString(jsonValue); - - NetworkPermission action_id = NetworkActions::FindCommandByPermissionName(permission); - if (action_id != NetworkPermission::Count) + json_t jsonGroup = { + { "id", Id }, + { "name", GetName() }, + }; + json_t actionsArray = json_t::array(); + for (size_t i = 0; i < NetworkActions::Actions.size(); i++) { - group.ToggleActionPermission(action_id); + if (CanPerformAction(static_cast(i))) + { + actionsArray.emplace_back(NetworkActions::Actions[i].PermissionName); + } + } + jsonGroup["permissions"] = actionsArray; + return jsonGroup; + } + + const std::string& NetworkGroup::GetName() const noexcept + { + return _name; + } + + void NetworkGroup::SetName(std::string_view name) + { + _name = name; + } + + void NetworkGroup::Read(NetworkPacket& packet) + { + packet >> Id; + SetName(packet.ReadString()); + for (auto& action : ActionsAllowed) + { + packet >> action; } } - return group; -} -json_t NetworkGroup::ToJson() const -{ - json_t jsonGroup = { - { "id", Id }, - { "name", GetName() }, - }; - json_t actionsArray = json_t::array(); - for (size_t i = 0; i < NetworkActions::Actions.size(); i++) + void NetworkGroup::Write(NetworkPacket& packet) const { - if (CanPerformAction(static_cast(i))) + packet << Id; + packet.WriteString(GetName().c_str()); + for (const auto& action : ActionsAllowed) { - actionsArray.emplace_back(NetworkActions::Actions[i].PermissionName); + packet << action; } } - jsonGroup["permissions"] = actionsArray; - return jsonGroup; -} -const std::string& NetworkGroup::GetName() const noexcept -{ - return _name; -} - -void NetworkGroup::SetName(std::string_view name) -{ - _name = name; -} - -void NetworkGroup::Read(NetworkPacket& packet) -{ - packet >> Id; - SetName(packet.ReadString()); - for (auto& action : ActionsAllowed) + void NetworkGroup::ToggleActionPermission(NetworkPermission index) { - packet >> action; + size_t index_st = static_cast(index); + size_t byte = index_st / 8; + size_t bit = index_st % 8; + if (byte >= ActionsAllowed.size()) + { + return; + } + ActionsAllowed[byte] ^= (1 << bit); } -} -void NetworkGroup::Write(NetworkPacket& packet) const -{ - packet << Id; - packet.WriteString(GetName().c_str()); - for (const auto& action : ActionsAllowed) + bool NetworkGroup::CanPerformAction(NetworkPermission index) const noexcept { - packet << action; + size_t index_st = static_cast(index); + size_t byte = index_st / 8; + size_t bit = index_st % 8; + if (byte >= ActionsAllowed.size()) + { + return false; + } + return (ActionsAllowed[byte] & (1 << bit)) != 0; } -} -void NetworkGroup::ToggleActionPermission(NetworkPermission index) -{ - size_t index_st = static_cast(index); - size_t byte = index_st / 8; - size_t bit = index_st % 8; - if (byte >= ActionsAllowed.size()) - { - return; - } - ActionsAllowed[byte] ^= (1 << bit); -} - -bool NetworkGroup::CanPerformAction(NetworkPermission index) const noexcept -{ - size_t index_st = static_cast(index); - size_t byte = index_st / 8; - size_t bit = index_st % 8; - if (byte >= ActionsAllowed.size()) + bool NetworkGroup::CanPerformCommand(GameCommand command) const { + NetworkPermission action = NetworkActions::FindCommand(command); + if (action != NetworkPermission::Count) + { + return CanPerformAction(action); + } return false; } - return (ActionsAllowed[byte] & (1 << bit)) != 0; -} - -bool NetworkGroup::CanPerformCommand(GameCommand command) const -{ - NetworkPermission action = NetworkActions::FindCommand(command); - if (action != NetworkPermission::Count) - { - return CanPerformAction(action); - } - return false; -} +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/NetworkGroup.h b/src/openrct2/network/NetworkGroup.h index ae7e284435..1cbf6d3c31 100644 --- a/src/openrct2/network/NetworkGroup.h +++ b/src/openrct2/network/NetworkGroup.h @@ -16,39 +16,42 @@ #include #include -enum class NetworkPermission : uint32_t; - -class NetworkGroup final +namespace OpenRCT2::Network { -public: - std::array ActionsAllowed{}; - uint8_t Id = 0; + enum class NetworkPermission : uint32_t; - /** - * Creates a NetworkGroup object from a JSON object - * - * @param json JSON data source - * @return A NetworkGroup object - * @note json is deliberately left non-const: json_t behaviour changes when const - */ - static NetworkGroup FromJson(const json_t& json); + class NetworkGroup final + { + public: + std::array ActionsAllowed{}; + uint8_t Id = 0; - const std::string& GetName() const noexcept; - void SetName(std::string_view name); + /** + * Creates a NetworkGroup object from a JSON object + * + * @param json JSON data source + * @return A NetworkGroup object + * @note json is deliberately left non-const: json_t behaviour changes when const + */ + static NetworkGroup FromJson(const json_t& json); - void Read(NetworkPacket& packet); - void Write(NetworkPacket& packet) const; - void ToggleActionPermission(NetworkPermission index); - bool CanPerformAction(NetworkPermission index) const noexcept; - bool CanPerformCommand(GameCommand command) const; + const std::string& GetName() const noexcept; + void SetName(std::string_view name); - /** - * Serialise a NetworkGroup object into a JSON object - * - * @return JSON representation of the NetworkGroup object - */ - json_t ToJson() const; + void Read(NetworkPacket& packet); + void Write(NetworkPacket& packet) const; + void ToggleActionPermission(NetworkPermission index); + bool CanPerformAction(NetworkPermission index) const noexcept; + bool CanPerformCommand(GameCommand command) const; -private: - std::string _name; -}; + /** + * Serialise a NetworkGroup object into a JSON object + * + * @return JSON representation of the NetworkGroup object + */ + json_t ToJson() const; + + private: + std::string _name; + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkKey.cpp b/src/openrct2/network/NetworkKey.cpp index 112537a5d0..ec98b4abb5 100644 --- a/src/openrct2/network/NetworkKey.cpp +++ b/src/openrct2/network/NetworkKey.cpp @@ -19,199 +19,200 @@ #include -using namespace OpenRCT2; - -NetworkKey::NetworkKey() = default; -NetworkKey::~NetworkKey() = default; - -void NetworkKey::Unload() +namespace OpenRCT2::Network { - _key = nullptr; -} + NetworkKey::NetworkKey() = default; + NetworkKey::~NetworkKey() = default; -bool NetworkKey::Generate() -{ - try + void NetworkKey::Unload() { - _key = Crypt::CreateRSAKey(); - _key->Generate(); - return true; - } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::Generate failed: %s", e.what()); - return false; - } -} - -bool NetworkKey::LoadPrivate(OpenRCT2::IStream* stream) -{ - Guard::ArgumentNotNull(stream); - - size_t size = static_cast(stream->GetLength()); - if (size == static_cast(-1)) - { - LOG_ERROR("unknown size, refusing to load key"); - return false; - } - if (size > 4 * 1024 * 1024) - { - LOG_ERROR("Key file suspiciously large, refusing to load it"); - return false; + _key = nullptr; } - std::string pem(size, '\0'); - stream->Read(pem.data(), pem.size()); - - try + bool NetworkKey::Generate() { - _key = Crypt::CreateRSAKey(); - _key->SetPrivate(pem); - return true; - } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::LoadPrivate failed: %s", e.what()); - return false; - } -} - -bool NetworkKey::LoadPublic(OpenRCT2::IStream* stream) -{ - Guard::ArgumentNotNull(stream); - - size_t size = static_cast(stream->GetLength()); - if (size == static_cast(-1)) - { - LOG_ERROR("unknown size, refusing to load key"); - return false; - } - if (size > 4 * 1024 * 1024) - { - LOG_ERROR("Key file suspiciously large, refusing to load it"); - return false; + try + { + _key = Crypt::CreateRSAKey(); + _key->Generate(); + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::Generate failed: %s", e.what()); + return false; + } } - std::string pem(size, '\0'); - stream->Read(pem.data(), pem.size()); - - try + bool NetworkKey::LoadPrivate(OpenRCT2::IStream* stream) { - _key = Crypt::CreateRSAKey(); - _key->SetPublic(pem); - return true; - } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::LoadPublic failed: %s", e.what()); - return false; - } -} + Guard::ArgumentNotNull(stream); -bool NetworkKey::SavePrivate(OpenRCT2::IStream* stream) -{ - try + size_t size = static_cast(stream->GetLength()); + if (size == static_cast(-1)) + { + LOG_ERROR("unknown size, refusing to load key"); + return false; + } + if (size > 4 * 1024 * 1024) + { + LOG_ERROR("Key file suspiciously large, refusing to load it"); + return false; + } + + std::string pem(size, '\0'); + stream->Read(pem.data(), pem.size()); + + try + { + _key = Crypt::CreateRSAKey(); + _key->SetPrivate(pem); + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::LoadPrivate failed: %s", e.what()); + return false; + } + } + + bool NetworkKey::LoadPublic(OpenRCT2::IStream* stream) + { + Guard::ArgumentNotNull(stream); + + size_t size = static_cast(stream->GetLength()); + if (size == static_cast(-1)) + { + LOG_ERROR("unknown size, refusing to load key"); + return false; + } + if (size > 4 * 1024 * 1024) + { + LOG_ERROR("Key file suspiciously large, refusing to load it"); + return false; + } + + std::string pem(size, '\0'); + stream->Read(pem.data(), pem.size()); + + try + { + _key = Crypt::CreateRSAKey(); + _key->SetPublic(pem); + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::LoadPublic failed: %s", e.what()); + return false; + } + } + + bool NetworkKey::SavePrivate(OpenRCT2::IStream* stream) + { + try + { + if (_key == nullptr) + { + throw std::runtime_error("No key loaded"); + } + auto pem = _key->GetPrivate(); + stream->Write(pem.data(), pem.size()); + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::SavePrivate failed: %s", e.what()); + return false; + } + } + + bool NetworkKey::SavePublic(OpenRCT2::IStream* stream) + { + try + { + if (_key == nullptr) + { + throw std::runtime_error("No key loaded"); + } + auto pem = _key->GetPublic(); + stream->Write(pem.data(), pem.size()); + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::SavePublic failed: %s", e.what()); + return false; + } + } + + std::string NetworkKey::PublicKeyString() { if (_key == nullptr) { throw std::runtime_error("No key loaded"); } - auto pem = _key->GetPrivate(); - stream->Write(pem.data(), pem.size()); - return true; + return _key->GetPublic(); } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::SavePrivate failed: %s", e.what()); - return false; - } -} -bool NetworkKey::SavePublic(OpenRCT2::IStream* stream) -{ - try + /** + * @brief NetworkKey::PublicKeyHash + * Computes a short, human-readable (e.g. asciif-ied hex) hash for a given + * public key. Serves a purpose of easy identification keys in multiplayer + * overview, multiplayer settings. + * + * In particular, any of digest functions applied to a standardised key + * representation, like PEM, will be sufficient. + * + * @return returns a string containing key hash. + */ + std::string NetworkKey::PublicKeyHash() { - if (_key == nullptr) + try { - throw std::runtime_error("No key loaded"); + std::string key = PublicKeyString(); + if (key.empty()) + { + throw std::runtime_error("No key found"); + } + auto hash = Crypt::SHA1(key.c_str(), key.size()); + return String::StringFromHex(hash); } - auto pem = _key->GetPublic(); - stream->Write(pem.data(), pem.size()); - return true; - } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::SavePublic failed: %s", e.what()); - return false; - } -} - -std::string NetworkKey::PublicKeyString() -{ - if (_key == nullptr) - { - throw std::runtime_error("No key loaded"); - } - return _key->GetPublic(); -} - -/** - * @brief NetworkKey::PublicKeyHash - * Computes a short, human-readable (e.g. asciif-ied hex) hash for a given - * public key. Serves a purpose of easy identification keys in multiplayer - * overview, multiplayer settings. - * - * In particular, any of digest functions applied to a standardised key - * representation, like PEM, will be sufficient. - * - * @return returns a string containing key hash. - */ -std::string NetworkKey::PublicKeyHash() -{ - try - { - std::string key = PublicKeyString(); - if (key.empty()) + catch (const std::exception& e) { - throw std::runtime_error("No key found"); + LOG_ERROR("Failed to create hash of public key: %s", e.what()); } - auto hash = Crypt::SHA1(key.c_str(), key.size()); - return String::StringFromHex(hash); + return nullptr; } - catch (const std::exception& e) - { - LOG_ERROR("Failed to create hash of public key: %s", e.what()); - } - return nullptr; -} -bool NetworkKey::Sign(const uint8_t* md, const size_t len, std::vector& signature) const -{ - try + bool NetworkKey::Sign(const uint8_t* md, const size_t len, std::vector& signature) const { - auto rsa = Crypt::CreateRSA(); - signature = rsa->SignData(*_key, md, len); - return true; + try + { + auto rsa = Crypt::CreateRSA(); + signature = rsa->SignData(*_key, md, len); + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::Sign failed: %s", e.what()); + return false; + } } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::Sign failed: %s", e.what()); - return false; - } -} -bool NetworkKey::Verify(const uint8_t* md, const size_t len, const std::vector& signature) const -{ - try + bool NetworkKey::Verify(const uint8_t* md, const size_t len, const std::vector& signature) const { - auto rsa = Crypt::CreateRSA(); - return rsa->VerifyData(*_key, md, len, signature.data(), signature.size()); + try + { + auto rsa = Crypt::CreateRSA(); + return rsa->VerifyData(*_key, md, len, signature.data(), signature.size()); + } + catch (const std::exception& e) + { + LOG_ERROR("NetworkKey::Verify failed: %s", e.what()); + return false; + } } - catch (const std::exception& e) - { - LOG_ERROR("NetworkKey::Verify failed: %s", e.what()); - return false; - } -} +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkKey.h b/src/openrct2/network/NetworkKey.h index 8f893fde72..28dc5c5027 100644 --- a/src/openrct2/network/NetworkKey.h +++ b/src/openrct2/network/NetworkKey.h @@ -25,25 +25,28 @@ namespace OpenRCT2::Crypt class RsaKey; } -class NetworkKey final +namespace OpenRCT2::Network { -public: - NetworkKey(); - ~NetworkKey(); - bool Generate(); - bool LoadPrivate(OpenRCT2::IStream* stream); - bool LoadPublic(OpenRCT2::IStream* stream); - bool SavePrivate(OpenRCT2::IStream* stream); - bool SavePublic(OpenRCT2::IStream* stream); - std::string PublicKeyString(); - std::string PublicKeyHash(); - void Unload(); - bool Sign(const uint8_t* md, const size_t len, std::vector& signature) const; - bool Verify(const uint8_t* md, const size_t len, const std::vector& signature) const; + class NetworkKey final + { + public: + NetworkKey(); + ~NetworkKey(); + bool Generate(); + bool LoadPrivate(OpenRCT2::IStream* stream); + bool LoadPublic(OpenRCT2::IStream* stream); + bool SavePrivate(OpenRCT2::IStream* stream); + bool SavePublic(OpenRCT2::IStream* stream); + std::string PublicKeyString(); + std::string PublicKeyHash(); + void Unload(); + bool Sign(const uint8_t* md, const size_t len, std::vector& signature) const; + bool Verify(const uint8_t* md, const size_t len, const std::vector& signature) const; -private: - NetworkKey(const NetworkKey&) = delete; - std::unique_ptr _key; -}; + private: + NetworkKey(const NetworkKey&) = delete; + std::unique_ptr _key; + }; +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkPacket.cpp b/src/openrct2/network/NetworkPacket.cpp index b3010f4015..8e3415ec39 100644 --- a/src/openrct2/network/NetworkPacket.cpp +++ b/src/openrct2/network/NetworkPacket.cpp @@ -15,97 +15,100 @@ #include -NetworkPacket::NetworkPacket(NetworkCommand id) noexcept - : Header{ 0, id } +namespace OpenRCT2::Network { -} - -uint8_t* NetworkPacket::GetData() noexcept -{ - return Data.data(); -} - -const uint8_t* NetworkPacket::GetData() const noexcept -{ - return Data.data(); -} - -NetworkCommand NetworkPacket::GetCommand() const noexcept -{ - return Header.Id; -} - -void NetworkPacket::Clear() noexcept -{ - BytesTransferred = 0; - BytesRead = 0; - Data.clear(); -} - -bool NetworkPacket::CommandRequiresAuth() const noexcept -{ - switch (GetCommand()) + NetworkPacket::NetworkPacket(NetworkCommand id) noexcept + : Header{ 0, id } { - case NetworkCommand::Ping: - case NetworkCommand::Auth: - case NetworkCommand::Token: - case NetworkCommand::GameInfo: - case NetworkCommand::ObjectsList: - case NetworkCommand::ScriptsHeader: - case NetworkCommand::ScriptsData: - case NetworkCommand::MapRequest: - case NetworkCommand::Heartbeat: - return false; - default: - return true; - } -} - -void NetworkPacket::Write(const void* bytes, size_t size) -{ - const uint8_t* src = reinterpret_cast(bytes); - Data.insert(Data.end(), src, src + size); -} - -void NetworkPacket::WriteString(std::string_view s) -{ - Write(s.data(), s.size()); - Data.push_back(0); -} - -const uint8_t* NetworkPacket::Read(size_t size) -{ - if (BytesRead + size > Data.size()) - { - return nullptr; } - const uint8_t* data = Data.data() + BytesRead; - BytesRead += size; - return data; -} - -std::string_view NetworkPacket::ReadString() -{ - if (BytesRead >= Data.size()) - return {}; - - const char* str = reinterpret_cast(Data.data() + BytesRead); - - size_t stringLen = 0; - while (BytesRead < Data.size() && str[stringLen] != '\0') + uint8_t* NetworkPacket::GetData() noexcept { + return Data.data(); + } + + const uint8_t* NetworkPacket::GetData() const noexcept + { + return Data.data(); + } + + NetworkCommand NetworkPacket::GetCommand() const noexcept + { + return Header.Id; + } + + void NetworkPacket::Clear() noexcept + { + BytesTransferred = 0; + BytesRead = 0; + Data.clear(); + } + + bool NetworkPacket::CommandRequiresAuth() const noexcept + { + switch (GetCommand()) + { + case NetworkCommand::Ping: + case NetworkCommand::Auth: + case NetworkCommand::Token: + case NetworkCommand::GameInfo: + case NetworkCommand::ObjectsList: + case NetworkCommand::ScriptsHeader: + case NetworkCommand::ScriptsData: + case NetworkCommand::MapRequest: + case NetworkCommand::Heartbeat: + return false; + default: + return true; + } + } + + void NetworkPacket::Write(const void* bytes, size_t size) + { + const uint8_t* src = reinterpret_cast(bytes); + Data.insert(Data.end(), src, src + size); + } + + void NetworkPacket::WriteString(std::string_view s) + { + Write(s.data(), s.size()); + Data.push_back(0); + } + + const uint8_t* NetworkPacket::Read(size_t size) + { + if (BytesRead + size > Data.size()) + { + return nullptr; + } + + const uint8_t* data = Data.data() + BytesRead; + BytesRead += size; + return data; + } + + std::string_view NetworkPacket::ReadString() + { + if (BytesRead >= Data.size()) + return {}; + + const char* str = reinterpret_cast(Data.data() + BytesRead); + + size_t stringLen = 0; + while (BytesRead < Data.size() && str[stringLen] != '\0') + { + BytesRead++; + stringLen++; + } + + if (str[stringLen] != '\0') + return {}; + + // Skip null terminator. BytesRead++; - stringLen++; + + return std::string_view(str, stringLen); } - - if (str[stringLen] != '\0') - return {}; - - // Skip null terminator. - BytesRead++; - - return std::string_view(str, stringLen); -} +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/NetworkPacket.h b/src/openrct2/network/NetworkPacket.h index b71e402255..2ed0fd0974 100644 --- a/src/openrct2/network/NetworkPacket.h +++ b/src/openrct2/network/NetworkPacket.h @@ -16,68 +16,71 @@ #include #include -#pragma pack(push, 1) -struct PacketHeader +namespace OpenRCT2::Network { - uint16_t Size = 0; - NetworkCommand Id = NetworkCommand::Invalid; -}; -static_assert(sizeof(PacketHeader) == 6); +#pragma pack(push, 1) + struct PacketHeader + { + uint16_t Size = 0; + NetworkCommand Id = NetworkCommand::Invalid; + }; + static_assert(sizeof(PacketHeader) == 6); #pragma pack(pop) -struct NetworkPacket final -{ - NetworkPacket() noexcept = default; - NetworkPacket(NetworkCommand id) noexcept; - - uint8_t* GetData() noexcept; - const uint8_t* GetData() const noexcept; - - NetworkCommand GetCommand() const noexcept; - - void Clear() noexcept; - bool CommandRequiresAuth() const noexcept; - - const uint8_t* Read(size_t size); - std::string_view ReadString(); - - void Write(const void* bytes, size_t size); - void WriteString(std::string_view s); - - template - NetworkPacket& operator>>(T& value) + struct NetworkPacket final { - if (BytesRead + sizeof(value) > Header.Size) + NetworkPacket() noexcept = default; + NetworkPacket(NetworkCommand id) noexcept; + + uint8_t* GetData() noexcept; + const uint8_t* GetData() const noexcept; + + NetworkCommand GetCommand() const noexcept; + + void Clear() noexcept; + bool CommandRequiresAuth() const noexcept; + + const uint8_t* Read(size_t size); + std::string_view ReadString(); + + void Write(const void* bytes, size_t size); + void WriteString(std::string_view s); + + template + NetworkPacket& operator>>(T& value) { - value = T{}; + if (BytesRead + sizeof(value) > Header.Size) + { + value = T{}; + } + else + { + T local; + std::memcpy(&local, &GetData()[BytesRead], sizeof(local)); + value = ByteSwapBE(local); + BytesRead += sizeof(value); + } + return *this; } - else + + template + NetworkPacket& operator<<(T value) { - T local; - std::memcpy(&local, &GetData()[BytesRead], sizeof(local)); - value = ByteSwapBE(local); - BytesRead += sizeof(value); + T swapped = ByteSwapBE(value); + Write(&swapped, sizeof(T)); + return *this; } - return *this; - } - template - NetworkPacket& operator<<(T value) - { - T swapped = ByteSwapBE(value); - Write(&swapped, sizeof(T)); - return *this; - } + NetworkPacket& operator<<(DataSerialiser& data) + { + Write(static_cast(data.GetStream().GetData()), data.GetStream().GetLength()); + return *this; + } - NetworkPacket& operator<<(DataSerialiser& data) - { - Write(static_cast(data.GetStream().GetData()), data.GetStream().GetLength()); - return *this; - } - -public: - PacketHeader Header{}; - sfl::small_vector Data; - size_t BytesTransferred = 0; - size_t BytesRead = 0; -}; + public: + PacketHeader Header{}; + sfl::small_vector Data; + size_t BytesTransferred = 0; + size_t BytesRead = 0; + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkPlayer.cpp b/src/openrct2/network/NetworkPlayer.cpp index 93d09dfea7..3a3c2316ce 100644 --- a/src/openrct2/network/NetworkPlayer.cpp +++ b/src/openrct2/network/NetworkPlayer.cpp @@ -15,39 +15,42 @@ #include "../ui/WindowManager.h" #include "NetworkPacket.h" -void NetworkPlayer::SetName(std::string_view name) +namespace OpenRCT2::Network { - // 36 == 31 + strlen(" #255"); - Name = name.substr(0, 36); -} + void NetworkPlayer::SetName(std::string_view name) + { + // 36 == 31 + strlen(" #255"); + Name = name.substr(0, 36); + } -void NetworkPlayer::Read(NetworkPacket& packet) -{ - auto name = packet.ReadString(); - SetName(name); - packet >> Id >> Flags >> Group >> LastAction >> LastActionCoord.x >> LastActionCoord.y >> LastActionCoord.z >> MoneySpent - >> CommandsRan; -} + void NetworkPlayer::Read(NetworkPacket& packet) + { + auto name = packet.ReadString(); + SetName(name); + packet >> Id >> Flags >> Group >> LastAction >> LastActionCoord.x >> LastActionCoord.y >> LastActionCoord.z + >> MoneySpent >> CommandsRan; + } -void NetworkPlayer::Write(NetworkPacket& packet) -{ - packet.WriteString(Name); - packet << Id << Flags << Group << LastAction << LastActionCoord.x << LastActionCoord.y << LastActionCoord.z << MoneySpent - << CommandsRan; -} + void NetworkPlayer::Write(NetworkPacket& packet) + { + packet.WriteString(Name); + packet << Id << Flags << Group << LastAction << LastActionCoord.x << LastActionCoord.y << LastActionCoord.z + << MoneySpent << CommandsRan; + } -void NetworkPlayer::IncrementNumCommands() -{ - CommandsRan++; - auto* windowMgr = OpenRCT2::Ui::GetWindowManager(); - windowMgr->InvalidateByNumber(WindowClass::Player, Id); -} + void NetworkPlayer::IncrementNumCommands() + { + CommandsRan++; + auto* windowMgr = OpenRCT2::Ui::GetWindowManager(); + windowMgr->InvalidateByNumber(WindowClass::Player, Id); + } -void NetworkPlayer::AddMoneySpent(money64 cost) -{ - MoneySpent += cost; - auto* windowMgr = OpenRCT2::Ui::GetWindowManager(); - windowMgr->InvalidateByNumber(WindowClass::Player, Id); -} + void NetworkPlayer::AddMoneySpent(money64 cost) + { + MoneySpent += cost; + auto* windowMgr = OpenRCT2::Ui::GetWindowManager(); + windowMgr->InvalidateByNumber(WindowClass::Player, Id); + } +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/NetworkPlayer.h b/src/openrct2/network/NetworkPlayer.h index 19a1b6201f..8b6cb2df14 100644 --- a/src/openrct2/network/NetworkPlayer.h +++ b/src/openrct2/network/NetworkPlayer.h @@ -17,34 +17,38 @@ #include #include -struct NetworkPacket; struct Peep; -class NetworkPlayer final +namespace OpenRCT2::Network { -public: - uint8_t Id = 0; - std::string Name; - uint16_t Ping = 0; - uint8_t Flags = 0; - uint8_t Group = 0; - money64 MoneySpent = 0.00_GBP; - uint32_t CommandsRan = 0; - int32_t LastAction = -999; - uint32_t LastActionTime = 0; - CoordsXYZ LastActionCoord = {}; - Peep* PickupPeep = nullptr; - int32_t PickupPeepOldX = kLocationNull; - std::string KeyHash; - uint32_t LastDemolishRideTime = 0; - uint32_t LastPlaceSceneryTime = 0; - std::unordered_map CooldownTime; - NetworkPlayer() noexcept = default; + struct NetworkPacket; - void SetName(std::string_view name); + class NetworkPlayer final + { + public: + uint8_t Id = 0; + std::string Name; + uint16_t Ping = 0; + uint8_t Flags = 0; + uint8_t Group = 0; + money64 MoneySpent = 0.00_GBP; + uint32_t CommandsRan = 0; + int32_t LastAction = -999; + uint32_t LastActionTime = 0; + CoordsXYZ LastActionCoord = {}; + Peep* PickupPeep = nullptr; + int32_t PickupPeepOldX = kLocationNull; + std::string KeyHash; + uint32_t LastDemolishRideTime = 0; + uint32_t LastPlaceSceneryTime = 0; + std::unordered_map CooldownTime; + NetworkPlayer() noexcept = default; - void Read(NetworkPacket& packet); - void Write(NetworkPacket& packet); - void IncrementNumCommands(); - void AddMoneySpent(money64 cost); -}; + void SetName(std::string_view name); + + void Read(NetworkPacket& packet); + void Write(NetworkPacket& packet); + void IncrementNumCommands(); + void AddMoneySpent(money64 cost); + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkServer.h b/src/openrct2/network/NetworkServer.h index 38b3986657..e553b8c280 100644 --- a/src/openrct2/network/NetworkServer.h +++ b/src/openrct2/network/NetworkServer.h @@ -4,9 +4,12 @@ #ifndef DISABLE_NETWORK -class NetworkServer final : public NetworkBase +namespace OpenRCT2::Network { -public: -}; + class NetworkServer final : public NetworkBase + { + public: + }; +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkServerAdvertiser.cpp b/src/openrct2/network/NetworkServerAdvertiser.cpp index 9474103c4b..3ce192cccd 100644 --- a/src/openrct2/network/NetworkServerAdvertiser.cpp +++ b/src/openrct2/network/NetworkServerAdvertiser.cpp @@ -34,330 +34,331 @@ #include #include -using namespace OpenRCT2; - -enum class MasterServerStatus +namespace OpenRCT2::Network { - Ok = 200, - InvalidToken = 401, - ServerNotFound = 404, - InternalError = 500 -}; + enum class MasterServerStatus + { + Ok = 200, + InvalidToken = 401, + ServerNotFound = 404, + InternalError = 500 + }; #ifndef DISABLE_HTTP -using namespace std::chrono_literals; -constexpr int32_t kMasterServerRegisterTime = std::chrono::milliseconds(2min).count(); -constexpr int32_t kMasterServerHeartbeatTime = std::chrono::milliseconds(1min).count(); + using namespace std::chrono_literals; + constexpr int32_t kMasterServerRegisterTime = std::chrono::milliseconds(2min).count(); + constexpr int32_t kMasterServerHeartbeatTime = std::chrono::milliseconds(1min).count(); #endif -class NetworkServerAdvertiser final : public INetworkServerAdvertiser -{ -private: - uint16_t _port; + class NetworkServerAdvertiser final : public INetworkServerAdvertiser + { + private: + uint16_t _port; - std::unique_ptr _lanListener; - uint32_t _lastListenTime{}; + std::unique_ptr _lanListener; + uint32_t _lastListenTime{}; - AdvertiseStatus _status = AdvertiseStatus::unregistered; + AdvertiseStatus _status = AdvertiseStatus::unregistered; #ifndef DISABLE_HTTP - uint32_t _lastAdvertiseTime = 0; - uint32_t _lastHeartbeatTime = 0; + uint32_t _lastAdvertiseTime = 0; + uint32_t _lastHeartbeatTime = 0; - // Our unique token for this server - std::string _token; + // Our unique token for this server + std::string _token; - // Key received from the master server - std::string _key; + // Key received from the master server + std::string _key; - // See https://github.com/OpenRCT2/OpenRCT2/issues/6277 and 4953 - bool _forceIPv4 = false; + // See https://github.com/OpenRCT2/OpenRCT2/issues/6277 and 4953 + bool _forceIPv4 = false; #endif -public: - explicit NetworkServerAdvertiser(uint16_t port) - { - _port = port; - _lanListener = CreateUdpSocket(); - #ifndef DISABLE_HTTP - _key = GenerateAdvertiseKey(); - #endif - } - - AdvertiseStatus GetStatus() const override - { - return _status; - } - - void Update() override - { - UpdateLAN(); - #ifndef DISABLE_HTTP - if (Config::Get().network.Advertise) + public: + explicit NetworkServerAdvertiser(uint16_t port) { - UpdateWAN(); + _port = port; + _lanListener = CreateUdpSocket(); + #ifndef DISABLE_HTTP + _key = GenerateAdvertiseKey(); + #endif } - #endif - } -private: - void UpdateLAN() - { - auto ticks = Platform::GetTicks(); - if (ticks > _lastListenTime + 500) + AdvertiseStatus GetStatus() const override { - if (_lanListener->GetStatus() != SocketStatus::Listening) + return _status; + } + + void Update() override + { + UpdateLAN(); + #ifndef DISABLE_HTTP + if (Config::Get().network.Advertise) { - _lanListener->Listen(kNetworkLanBroadcastPort); + UpdateWAN(); + } + #endif + } + + private: + void UpdateLAN() + { + auto ticks = Platform::GetTicks(); + if (ticks > _lastListenTime + 500) + { + if (_lanListener->GetStatus() != SocketStatus::Listening) + { + _lanListener->Listen(kNetworkLanBroadcastPort); + } + else + { + char buffer[256]{}; + size_t recievedBytes{}; + std::unique_ptr endpoint; + auto p = _lanListener->ReceiveData(buffer, sizeof(buffer) - 1, &recievedBytes, &endpoint); + if (p == NetworkReadPacket::Success) + { + std::string sender = endpoint->GetHostname(); + LOG_VERBOSE("Received %zu bytes from %s on LAN broadcast port", recievedBytes, sender.c_str()); + if (String::equals(buffer, kNetworkLanBroadcastMsg)) + { + auto body = GetBroadcastJson(); + auto bodyDump = body.dump(); + size_t sendLen = bodyDump.size() + 1; + LOG_VERBOSE("Sending %zu bytes back to %s", sendLen, sender.c_str()); + _lanListener->SendData(*endpoint, bodyDump.c_str(), sendLen); + } + } + } + _lastListenTime = ticks; + } + } + + json_t GetBroadcastJson() + { + json_t root = NetworkGetServerInfoAsJson(); + root["port"] = _port; + return root; + } + + #ifndef DISABLE_HTTP + void UpdateWAN() + { + switch (_status) + { + case AdvertiseStatus::unregistered: + if (_lastAdvertiseTime == 0 || Platform::GetTicks() > _lastAdvertiseTime + kMasterServerRegisterTime) + { + if (_lastAdvertiseTime == 0) + { + Console::WriteLine("Registering server on master server"); + } + SendRegistration(_forceIPv4); + } + break; + case AdvertiseStatus::registered: + if (Platform::GetTicks() > _lastHeartbeatTime + kMasterServerHeartbeatTime) + { + SendHeartbeat(); + } + break; + // exhaust enum values to satisfy clang + case AdvertiseStatus::disabled: + break; + } + } + + void SendRegistration(bool forceIPv4) + { + _lastAdvertiseTime = Platform::GetTicks(); + + // Send the registration request + Http::Request request; + request.url = GetMasterServerUrl(); + request.method = Http::Method::POST; + request.forceIPv4 = forceIPv4; + + json_t body = { + { "key", _key }, + { "port", _port }, + }; + + if (!Config::Get().network.AdvertiseAddress.empty()) + { + body["address"] = Config::Get().network.AdvertiseAddress; + } + + request.body = body.dump(); + request.header["Content-Type"] = "application/json"; + + Http::DoAsync(request, [&](Http::Response response) -> void { + if (response.status != Http::Status::Ok) + { + Console::Error::WriteLine("Unable to connect to master server"); + return; + } + + json_t root = Json::FromString(response.body); + root = Json::AsObject(root); + this->OnRegistrationResponse(root); + }); + } + + void SendHeartbeat() + { + Http::Request request; + request.url = GetMasterServerUrl(); + request.method = Http::Method::PUT; + + json_t body = GetHeartbeatJson(); + request.body = body.dump(); + request.header["Content-Type"] = "application/json"; + + _lastHeartbeatTime = Platform::GetTicks(); + Http::DoAsync(request, [&](Http::Response response) -> void { + if (response.status != Http::Status::Ok) + { + Console::Error::WriteLine("Unable to connect to master server"); + return; + } + + json_t root = Json::FromString(response.body); + root = Json::AsObject(root); + this->OnHeartbeatResponse(root); + }); + } + + /** + * @param jsonRoot must be of JSON type object or null + * @note jsonRoot is deliberately left non-const: json_t behaviour changes when const + */ + void OnRegistrationResponse(json_t& jsonRoot) + { + Guard::Assert(jsonRoot.is_object(), "OnRegistrationResponse expects parameter jsonRoot to be object"); + + auto status = Json::GetEnum(jsonRoot["status"], MasterServerStatus::InternalError); + + if (status == MasterServerStatus::Ok) + { + Console::WriteLine("Server successfully registered on master server"); + json_t jsonToken = jsonRoot["token"]; + if (jsonToken.is_string()) + { + _token = Json::GetString(jsonToken); + _status = AdvertiseStatus::registered; + } } else { - char buffer[256]{}; - size_t recievedBytes{}; - std::unique_ptr endpoint; - auto p = _lanListener->ReceiveData(buffer, sizeof(buffer) - 1, &recievedBytes, &endpoint); - if (p == NetworkReadPacket::Success) + std::string message = Json::GetString(jsonRoot["message"]); + if (message.empty()) { - std::string sender = endpoint->GetHostname(); - LOG_VERBOSE("Received %zu bytes from %s on LAN broadcast port", recievedBytes, sender.c_str()); - if (String::equals(buffer, kNetworkLanBroadcastMsg)) - { - auto body = GetBroadcastJson(); - auto bodyDump = body.dump(); - size_t sendLen = bodyDump.size() + 1; - LOG_VERBOSE("Sending %zu bytes back to %s", sendLen, sender.c_str()); - _lanListener->SendData(*endpoint, bodyDump.c_str(), sendLen); - } + message = "Invalid response from server"; + } + Console::Error::WriteLine( + "Unable to advertise (%d): %s\n * Check that you have port forwarded %u\n * Try setting " + "advertise_address in config.ini", + status, message.c_str(), _port); + + // Hack for https://github.com/OpenRCT2/OpenRCT2/issues/6277 + // Master server may not reply correctly if using IPv6, retry forcing IPv4, + // don't wait the full timeout. + if (!_forceIPv4 && status == MasterServerStatus::InternalError) + { + _forceIPv4 = true; + _lastAdvertiseTime = 0; + LOG_INFO("Forcing HTTP(S) over IPv4"); } } - _lastListenTime = ticks; - } - } - - json_t GetBroadcastJson() - { - json_t root = NetworkGetServerInfoAsJson(); - root["port"] = _port; - return root; - } - - #ifndef DISABLE_HTTP - void UpdateWAN() - { - switch (_status) - { - case AdvertiseStatus::unregistered: - if (_lastAdvertiseTime == 0 || Platform::GetTicks() > _lastAdvertiseTime + kMasterServerRegisterTime) - { - if (_lastAdvertiseTime == 0) - { - Console::WriteLine("Registering server on master server"); - } - SendRegistration(_forceIPv4); - } - break; - case AdvertiseStatus::registered: - if (Platform::GetTicks() > _lastHeartbeatTime + kMasterServerHeartbeatTime) - { - SendHeartbeat(); - } - break; - // exhaust enum values to satisfy clang - case AdvertiseStatus::disabled: - break; - } - } - - void SendRegistration(bool forceIPv4) - { - _lastAdvertiseTime = Platform::GetTicks(); - - // Send the registration request - Http::Request request; - request.url = GetMasterServerUrl(); - request.method = Http::Method::POST; - request.forceIPv4 = forceIPv4; - - json_t body = { - { "key", _key }, - { "port", _port }, - }; - - if (!Config::Get().network.AdvertiseAddress.empty()) - { - body["address"] = Config::Get().network.AdvertiseAddress; } - request.body = body.dump(); - request.header["Content-Type"] = "application/json"; - - Http::DoAsync(request, [&](Http::Response response) -> void { - if (response.status != Http::Status::Ok) - { - Console::Error::WriteLine("Unable to connect to master server"); - return; - } - - json_t root = Json::FromString(response.body); - root = Json::AsObject(root); - this->OnRegistrationResponse(root); - }); - } - - void SendHeartbeat() - { - Http::Request request; - request.url = GetMasterServerUrl(); - request.method = Http::Method::PUT; - - json_t body = GetHeartbeatJson(); - request.body = body.dump(); - request.header["Content-Type"] = "application/json"; - - _lastHeartbeatTime = Platform::GetTicks(); - Http::DoAsync(request, [&](Http::Response response) -> void { - if (response.status != Http::Status::Ok) - { - Console::Error::WriteLine("Unable to connect to master server"); - return; - } - - json_t root = Json::FromString(response.body); - root = Json::AsObject(root); - this->OnHeartbeatResponse(root); - }); - } - - /** - * @param jsonRoot must be of JSON type object or null - * @note jsonRoot is deliberately left non-const: json_t behaviour changes when const - */ - void OnRegistrationResponse(json_t& jsonRoot) - { - Guard::Assert(jsonRoot.is_object(), "OnRegistrationResponse expects parameter jsonRoot to be object"); - - auto status = Json::GetEnum(jsonRoot["status"], MasterServerStatus::InternalError); - - if (status == MasterServerStatus::Ok) + /** + * @param jsonRoot must be of JSON type object or null + * @note jsonRoot is deliberately left non-const: json_t behaviour changes when const + */ + void OnHeartbeatResponse(json_t& jsonRoot) { - Console::WriteLine("Server successfully registered on master server"); - json_t jsonToken = jsonRoot["token"]; - if (jsonToken.is_string()) - { - _token = Json::GetString(jsonToken); - _status = AdvertiseStatus::registered; - } - } - else - { - std::string message = Json::GetString(jsonRoot["message"]); - if (message.empty()) - { - message = "Invalid response from server"; - } - Console::Error::WriteLine( - "Unable to advertise (%d): %s\n * Check that you have port forwarded %u\n * Try setting " - "advertise_address in config.ini", - status, message.c_str(), _port); + Guard::Assert(jsonRoot.is_object(), "OnHeartbeatResponse expects parameter jsonRoot to be object"); - // Hack for https://github.com/OpenRCT2/OpenRCT2/issues/6277 - // Master server may not reply correctly if using IPv6, retry forcing IPv4, - // don't wait the full timeout. - if (!_forceIPv4 && status == MasterServerStatus::InternalError) + auto status = Json::GetEnum(jsonRoot["status"], MasterServerStatus::InternalError); + if (status == MasterServerStatus::Ok) { - _forceIPv4 = true; + // Master server has successfully updated our server status + } + else if (status == MasterServerStatus::InvalidToken) + { + _status = AdvertiseStatus::unregistered; _lastAdvertiseTime = 0; - LOG_INFO("Forcing HTTP(S) over IPv4"); + Console::Error::WriteLine("Master server heartbeat failed: Invalid Token"); } } - } - /** - * @param jsonRoot must be of JSON type object or null - * @note jsonRoot is deliberately left non-const: json_t behaviour changes when const - */ - void OnHeartbeatResponse(json_t& jsonRoot) - { - Guard::Assert(jsonRoot.is_object(), "OnHeartbeatResponse expects parameter jsonRoot to be object"); - - auto status = Json::GetEnum(jsonRoot["status"], MasterServerStatus::InternalError); - if (status == MasterServerStatus::Ok) + json_t GetHeartbeatJson() { - // Master server has successfully updated our server status - } - else if (status == MasterServerStatus::InvalidToken) - { - _status = AdvertiseStatus::unregistered; - _lastAdvertiseTime = 0; - Console::Error::WriteLine("Master server heartbeat failed: Invalid Token"); - } - } + uint32_t numPlayers = NetworkGetNumVisiblePlayers(); - json_t GetHeartbeatJson() - { - uint32_t numPlayers = NetworkGetNumVisiblePlayers(); + json_t root = { + { "token", _token }, + { "players", numPlayers }, + }; - json_t root = { - { "token", _token }, - { "players", numPlayers }, - }; + const auto& gameState = getGameState(); + const auto& date = GetDate(); + json_t mapSize = { { "x", gameState.mapSize.x - 2 }, { "y", gameState.mapSize.y - 2 } }; + json_t gameInfo = { + { "mapSize", mapSize }, + { "day", date.GetMonthTicks() }, + { "month", date.GetMonthsElapsed() }, + { "guests", gameState.park.numGuestsInPark }, + { "parkValue", gameState.park.value }, + }; - const auto& gameState = getGameState(); - const auto& date = GetDate(); - json_t mapSize = { { "x", gameState.mapSize.x - 2 }, { "y", gameState.mapSize.y - 2 } }; - json_t gameInfo = { - { "mapSize", mapSize }, - { "day", date.GetMonthTicks() }, - { "month", date.GetMonthsElapsed() }, - { "guests", gameState.park.numGuestsInPark }, - { "parkValue", gameState.park.value }, - }; + if (!(gameState.park.flags & PARK_FLAGS_NO_MONEY)) + { + gameInfo["cash"] = gameState.park.cash; + } - if (!(gameState.park.flags & PARK_FLAGS_NO_MONEY)) - { - gameInfo["cash"] = gameState.park.cash; + root["gameInfo"] = gameInfo; + + return root; } - root["gameInfo"] = gameInfo; - - return root; - } - - static std::string GenerateAdvertiseKey() - { - // Generate a string of 16 random hex characters (64-integer key as a hex formatted string) - static constexpr char hexChars[] = { - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', - }; - - std::random_device rd; - std::uniform_int_distribution dist(0, static_cast(std::size(hexChars) - 1)); - - char key[17]; - for (int32_t i = 0; i < 16; i++) + static std::string GenerateAdvertiseKey() { - int32_t hexCharIndex = dist(rd); - key[i] = hexChars[hexCharIndex]; - } - key[std::size(key) - 1] = 0; - return key; - } + // Generate a string of 16 random hex characters (64-integer key as a hex formatted string) + static constexpr char hexChars[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', + }; - static std::string GetMasterServerUrl() - { - std::string result = kMasterServerURL; - if (!Config::Get().network.MasterServerUrl.empty()) - { - result = Config::Get().network.MasterServerUrl; + std::random_device rd; + std::uniform_int_distribution dist(0, static_cast(std::size(hexChars) - 1)); + + char key[17]; + for (int32_t i = 0; i < 16; i++) + { + int32_t hexCharIndex = dist(rd); + key[i] = hexChars[hexCharIndex]; + } + key[std::size(key) - 1] = 0; + return key; + } + + static std::string GetMasterServerUrl() + { + std::string result = kMasterServerURL; + if (!Config::Get().network.MasterServerUrl.empty()) + { + result = Config::Get().network.MasterServerUrl; + } + return result; } - return result; - } #endif -}; + }; -std::unique_ptr CreateServerAdvertiser(uint16_t port) -{ - return std::make_unique(port); -} + std::unique_ptr CreateServerAdvertiser(uint16_t port) + { + return std::make_unique(port); + } +} // namespace OpenRCT2::Network #endif // DISABLE_NETWORK diff --git a/src/openrct2/network/NetworkServerAdvertiser.h b/src/openrct2/network/NetworkServerAdvertiser.h index 62ab06f86e..7738df9b34 100644 --- a/src/openrct2/network/NetworkServerAdvertiser.h +++ b/src/openrct2/network/NetworkServerAdvertiser.h @@ -11,21 +11,24 @@ #include -enum class AdvertiseStatus +namespace OpenRCT2::Network { - disabled, - unregistered, - registered, -}; - -struct INetworkServerAdvertiser -{ - virtual ~INetworkServerAdvertiser() + enum class AdvertiseStatus { - } + disabled, + unregistered, + registered, + }; - virtual AdvertiseStatus GetStatus() const = 0; - virtual void Update() = 0; -}; + struct INetworkServerAdvertiser + { + virtual ~INetworkServerAdvertiser() + { + } -[[nodiscard]] std::unique_ptr CreateServerAdvertiser(uint16_t port); + virtual AdvertiseStatus GetStatus() const = 0; + virtual void Update() = 0; + }; + + [[nodiscard]] std::unique_ptr CreateServerAdvertiser(uint16_t port); +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkTypes.h b/src/openrct2/network/NetworkTypes.h index 9288a86ab5..3a488d936e 100644 --- a/src/openrct2/network/NetworkTypes.h +++ b/src/openrct2/network/NetworkTypes.h @@ -13,136 +13,139 @@ #include "../core/EnumUtils.hpp" #include "../ride/RideTypes.h" -enum +namespace OpenRCT2::Network { - SERVER_EVENT_PLAYER_JOINED, - SERVER_EVENT_PLAYER_DISCONNECTED, -}; + enum + { + SERVER_EVENT_PLAYER_JOINED, + SERVER_EVENT_PLAYER_DISCONNECTED, + }; -enum -{ - NETWORK_TICK_FLAG_CHECKSUMS = 1 << 0, -}; + enum + { + NETWORK_TICK_FLAG_CHECKSUMS = 1 << 0, + }; -enum -{ - NETWORK_MODE_NONE, - NETWORK_MODE_CLIENT, - NETWORK_MODE_SERVER -}; + enum + { + NETWORK_MODE_NONE, + NETWORK_MODE_CLIENT, + NETWORK_MODE_SERVER + }; -enum -{ - NETWORK_PLAYER_FLAG_ISSERVER = 1 << 0, -}; + enum + { + NETWORK_PLAYER_FLAG_ISSERVER = 1 << 0, + }; -enum -{ - NETWORK_STATUS_NONE, - NETWORK_STATUS_READY, - NETWORK_STATUS_CONNECTING, - NETWORK_STATUS_CONNECTED -}; + enum + { + NETWORK_STATUS_NONE, + NETWORK_STATUS_READY, + NETWORK_STATUS_CONNECTING, + NETWORK_STATUS_CONNECTED + }; -enum class NetworkAuth : int32_t -{ - None, - Requested, - Ok, - BadVersion, - BadName, - BadPassword, - VerificationFailure, - Full, - RequirePassword, - Verified, - UnknownKeyDisallowed -}; + enum class NetworkAuth : int32_t + { + None, + Requested, + Ok, + BadVersion, + BadName, + BadPassword, + VerificationFailure, + Full, + RequirePassword, + Verified, + UnknownKeyDisallowed + }; -enum class NetworkCommand : uint32_t -{ - Auth, - Map, - Chat, - Tick = 4, - PlayerList, - Ping, - PingList, - DisconnectMessage, - GameInfo, - ShowError, - GroupList, - Event, - Token, - ObjectsList, - MapRequest, - GameAction, - PlayerInfo, - RequestGameState, - GameState, - ScriptsHeader, - ScriptsData, - Heartbeat, - Max, - Invalid = static_cast(-1), -}; + enum class NetworkCommand : uint32_t + { + Auth, + Map, + Chat, + Tick = 4, + PlayerList, + Ping, + PingList, + DisconnectMessage, + GameInfo, + ShowError, + GroupList, + Event, + Token, + ObjectsList, + MapRequest, + GameAction, + PlayerInfo, + RequestGameState, + GameState, + ScriptsHeader, + ScriptsData, + Heartbeat, + Max, + Invalid = static_cast(-1), + }; -static_assert(NetworkCommand::GameInfo == static_cast(9), "Master server expects this to be 9"); + static_assert(NetworkCommand::GameInfo == static_cast(9), "Master server expects this to be 9"); -enum class NetworkServerStatus -{ - Ok, - Desynced -}; + enum class NetworkServerStatus + { + Ok, + Desynced + }; -struct NetworkServerState -{ - NetworkServerStatus state = NetworkServerStatus::Ok; - uint32_t desyncTick = 0; - uint32_t tick = 0; - uint32_t srand0 = 0; - bool gamestateSnapshotsEnabled = false; -}; + struct NetworkServerState + { + NetworkServerStatus state = NetworkServerStatus::Ok; + uint32_t desyncTick = 0; + uint32_t tick = 0; + uint32_t srand0 = 0; + bool gamestateSnapshotsEnabled = false; + }; // Structure is used for networking specific fields with meaning, // this structure can be used in combination with DataSerialiser // to provide extra details with template specialization. #pragma pack(push, 1) -template -struct NetworkObjectId -{ - NetworkObjectId(T v) - : id(v) + template + struct NetworkObjectId { - } - NetworkObjectId() - : id(T(-1)) - { - } - operator T() const - { - return id; - } - T id; -}; + NetworkObjectId(T v) + : id(v) + { + } + NetworkObjectId() + : id(T(-1)) + { + } + operator T() const + { + return id; + } + T id; + }; #pragma pack(pop) -// NOTE: When adding new types make sure to have no duplicate _TypeID's otherwise -// there is no way to specialize templates if they have the exact symbol. -using NetworkPlayerId_t = NetworkObjectId; -using NetworkCheatType_t = NetworkObjectId; + // NOTE: When adding new types make sure to have no duplicate _TypeID's otherwise + // there is no way to specialize templates if they have the exact symbol. + using NetworkPlayerId_t = NetworkObjectId; + using NetworkCheatType_t = NetworkObjectId; -enum class NetworkStatisticsGroup : uint32_t -{ - Total = 0, // Entire network traffic. - Base, // Messages such as Tick, Ping - Commands, // Command / Game actions - MapData, - Max, -}; + enum class NetworkStatisticsGroup : uint32_t + { + Total = 0, // Entire network traffic. + Base, // Messages such as Tick, Ping + Commands, // Command / Game actions + MapData, + Max, + }; -struct NetworkStats -{ - uint64_t bytesReceived[EnumValue(NetworkStatisticsGroup::Max)]; - uint64_t bytesSent[EnumValue(NetworkStatisticsGroup::Max)]; -}; + struct NetworkStats + { + uint64_t bytesReceived[EnumValue(NetworkStatisticsGroup::Max)]; + uint64_t bytesSent[EnumValue(NetworkStatisticsGroup::Max)]; + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/NetworkUser.cpp b/src/openrct2/network/NetworkUser.cpp index c7beec5dcb..b5f34de71b 100644 --- a/src/openrct2/network/NetworkUser.cpp +++ b/src/openrct2/network/NetworkUser.cpp @@ -22,198 +22,199 @@ #include -using namespace OpenRCT2; - -constexpr const utf8* kUserStoreFilename = "users.json"; - -std::unique_ptr NetworkUser::FromJson(const json_t& jsonData) +namespace OpenRCT2::Network { - Guard::Assert(jsonData.is_object(), "NetworkUser::FromJson expects parameter jsonData to be object"); + constexpr const utf8* kUserStoreFilename = "users.json"; - const std::string hash = Json::GetString(jsonData["hash"]); - const std::string name = Json::GetString(jsonData["name"]); - json_t jsonGroupId = jsonData["groupId"]; - - std::unique_ptr user = nullptr; - if (!hash.empty() && !name.empty()) + std::unique_ptr NetworkUser::FromJson(const json_t& jsonData) { - user = std::make_unique(); - user->Hash = hash; - user->Name = name; - if (jsonGroupId.is_number_integer()) + Guard::Assert(jsonData.is_object(), "NetworkUser::FromJson expects parameter jsonData to be object"); + + const std::string hash = Json::GetString(jsonData["hash"]); + const std::string name = Json::GetString(jsonData["name"]); + json_t jsonGroupId = jsonData["groupId"]; + + std::unique_ptr user = nullptr; + if (!hash.empty() && !name.empty()) { - user->GroupId = Json::GetNumber(jsonGroupId); - } - user->Remove = false; - } - return user; -} - -json_t NetworkUser::ToJson() const -{ - json_t jsonData; - jsonData["hash"] = Hash; - jsonData["name"] = Name; - - json_t jsonGroupId; - if (GroupId.has_value()) - { - jsonGroupId = *GroupId; - } - jsonData["groupId"] = jsonGroupId; - - return jsonData; -} - -void NetworkUserManager::Load() -{ - const auto path = GetStorePath(); - - if (File::Exists(path)) - { - _usersByHash.clear(); - - try - { - json_t jsonUsers = Json::ReadFromFile(path); - for (const auto& jsonUser : jsonUsers) + user = std::make_unique(); + user->Hash = hash; + user->Name = name; + if (jsonGroupId.is_number_integer()) { - if (jsonUser.is_object()) + user->GroupId = Json::GetNumber(jsonGroupId); + } + user->Remove = false; + } + return user; + } + + json_t NetworkUser::ToJson() const + { + json_t jsonData; + jsonData["hash"] = Hash; + jsonData["name"] = Name; + + json_t jsonGroupId; + if (GroupId.has_value()) + { + jsonGroupId = *GroupId; + } + jsonData["groupId"] = jsonGroupId; + + return jsonData; + } + + void NetworkUserManager::Load() + { + const auto path = GetStorePath(); + + if (File::Exists(path)) + { + _usersByHash.clear(); + + try + { + json_t jsonUsers = Json::ReadFromFile(path); + for (const auto& jsonUser : jsonUsers) { - auto networkUser = NetworkUser::FromJson(jsonUser); - if (networkUser != nullptr) + if (jsonUser.is_object()) { - _usersByHash[networkUser->Hash] = std::move(networkUser); + auto networkUser = NetworkUser::FromJson(jsonUser); + if (networkUser != nullptr) + { + _usersByHash[networkUser->Hash] = std::move(networkUser); + } } } } - } - catch (const std::exception& ex) - { - Console::Error::WriteLine("Failed to read %s as JSON. %s", path.c_str(), ex.what()); - } - } -} - -void NetworkUserManager::Save() -{ - const auto path = GetStorePath(); - - json_t jsonUsers; - try - { - if (File::Exists(path)) - { - jsonUsers = Json::ReadFromFile(path); - } - } - catch (const std::exception&) - { - } - - // Update existing users - std::unordered_set savedHashes; - for (auto it = jsonUsers.begin(); it != jsonUsers.end();) - { - json_t jsonUser = *it; - if (!jsonUser.is_object()) - { - continue; - } - std::string hashString = Json::GetString(jsonUser["hash"]); - - const auto networkUser = GetUserByHash(hashString); - if (networkUser != nullptr) - { - if (networkUser->Remove) + catch (const std::exception& ex) + { + Console::Error::WriteLine("Failed to read %s as JSON. %s", path.c_str(), ex.what()); + } + } + } + + void NetworkUserManager::Save() + { + const auto path = GetStorePath(); + + json_t jsonUsers; + try + { + if (File::Exists(path)) + { + jsonUsers = Json::ReadFromFile(path); + } + } + catch (const std::exception&) + { + } + + // Update existing users + std::unordered_set savedHashes; + for (auto it = jsonUsers.begin(); it != jsonUsers.end();) + { + json_t jsonUser = *it; + if (!jsonUser.is_object()) { - it = jsonUsers.erase(it); - // erase advances the iterator so make sure we don't do it again continue; } + std::string hashString = Json::GetString(jsonUser["hash"]); - // replace the existing element in jsonUsers - *it = networkUser->ToJson(); - savedHashes.insert(hashString); + const auto networkUser = GetUserByHash(hashString); + if (networkUser != nullptr) + { + if (networkUser->Remove) + { + it = jsonUsers.erase(it); + // erase advances the iterator so make sure we don't do it again + continue; + } + + // replace the existing element in jsonUsers + *it = networkUser->ToJson(); + savedHashes.insert(hashString); + } + + it++; } - it++; - } - - // Add new users - for (const auto& kvp : _usersByHash) - { - const auto& networkUser = kvp.second; - if (!networkUser->Remove && savedHashes.find(networkUser->Hash) == savedHashes.end()) + // Add new users + for (const auto& kvp : _usersByHash) { - jsonUsers.push_back(networkUser->ToJson()); + const auto& networkUser = kvp.second; + if (!networkUser->Remove && savedHashes.find(networkUser->Hash) == savedHashes.end()) + { + jsonUsers.push_back(networkUser->ToJson()); + } } + + Json::WriteToFile(path, jsonUsers); } - Json::WriteToFile(path, jsonUsers); -} - -void NetworkUserManager::UnsetUsersOfGroup(uint8_t groupId) -{ - for (const auto& kvp : _usersByHash) + void NetworkUserManager::UnsetUsersOfGroup(uint8_t groupId) { - auto& networkUser = kvp.second; - if (networkUser->GroupId.has_value() && *networkUser->GroupId == groupId) + for (const auto& kvp : _usersByHash) { - networkUser->GroupId = std::nullopt; + auto& networkUser = kvp.second; + if (networkUser->GroupId.has_value() && *networkUser->GroupId == groupId) + { + networkUser->GroupId = std::nullopt; + } } } -} -void NetworkUserManager::RemoveUser(const std::string& hash) -{ - NetworkUser* networkUser = const_cast(GetUserByHash(hash)); - if (networkUser != nullptr) + void NetworkUserManager::RemoveUser(const std::string& hash) { - networkUser->Remove = true; - } -} - -const NetworkUser* NetworkUserManager::GetUserByHash(const std::string& hash) const -{ - auto it = _usersByHash.find(hash); - if (it != _usersByHash.end()) - { - return it->second.get(); - } - return nullptr; -} - -const NetworkUser* NetworkUserManager::GetUserByName(const std::string& name) const -{ - for (const auto& kvp : _usersByHash) - { - const auto& networkUser = kvp.second; - if (String::iequals(name, networkUser->Name)) + NetworkUser* networkUser = const_cast(GetUserByHash(hash)); + if (networkUser != nullptr) { - return networkUser.get(); + networkUser->Remove = true; } } - return nullptr; -} -NetworkUser* NetworkUserManager::GetOrAddUser(const std::string& hash) -{ - NetworkUser* networkUser = const_cast(GetUserByHash(hash)); - if (networkUser == nullptr) + const NetworkUser* NetworkUserManager::GetUserByHash(const std::string& hash) const { - auto newNetworkUser = std::make_unique(); - newNetworkUser->Hash = hash; - networkUser = newNetworkUser.get(); - _usersByHash[hash] = std::move(newNetworkUser); + auto it = _usersByHash.find(hash); + if (it != _usersByHash.end()) + { + return it->second.get(); + } + return nullptr; } - return networkUser; -} -u8string NetworkUserManager::GetStorePath() -{ - auto& env = OpenRCT2::GetContext()->GetPlatformEnvironment(); - return Path::Combine(env.GetDirectoryPath(OpenRCT2::DirBase::user), kUserStoreFilename); -} + const NetworkUser* NetworkUserManager::GetUserByName(const std::string& name) const + { + for (const auto& kvp : _usersByHash) + { + const auto& networkUser = kvp.second; + if (String::iequals(name, networkUser->Name)) + { + return networkUser.get(); + } + } + return nullptr; + } + + NetworkUser* NetworkUserManager::GetOrAddUser(const std::string& hash) + { + NetworkUser* networkUser = const_cast(GetUserByHash(hash)); + if (networkUser == nullptr) + { + auto newNetworkUser = std::make_unique(); + newNetworkUser->Hash = hash; + networkUser = newNetworkUser.get(); + _usersByHash[hash] = std::move(newNetworkUser); + } + return networkUser; + } + + u8string NetworkUserManager::GetStorePath() + { + auto& env = OpenRCT2::GetContext()->GetPlatformEnvironment(); + return Path::Combine(env.GetDirectoryPath(OpenRCT2::DirBase::user), kUserStoreFilename); + } +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/NetworkUser.h b/src/openrct2/network/NetworkUser.h index f25a63e609..eb5e7f9a8c 100644 --- a/src/openrct2/network/NetworkUser.h +++ b/src/openrct2/network/NetworkUser.h @@ -16,52 +16,55 @@ #include #include -class NetworkUser final +namespace OpenRCT2::Network { -public: - std::string Hash; - std::string Name; - std::optional GroupId; - bool Remove; + class NetworkUser final + { + public: + std::string Hash; + std::string Name; + std::optional GroupId; + bool Remove; - /** - * Creates a NetworkUser object from a JSON object - * @param jsonData Must be a JSON node of type object - * @return Pointer to a new NetworkUser object - * @note jsonData is deliberately left non-const: json_t behaviour changes when const - */ - static std::unique_ptr FromJson(const json_t& jsonData); + /** + * Creates a NetworkUser object from a JSON object + * @param jsonData Must be a JSON node of type object + * @return Pointer to a new NetworkUser object + * @note jsonData is deliberately left non-const: json_t behaviour changes when const + */ + static std::unique_ptr FromJson(const json_t& jsonData); - /** - * Serialise a NetworkUser object into a JSON object - * - * @return JSON representation of the NetworkUser object - */ - json_t ToJson() const; -}; + /** + * Serialise a NetworkUser object into a JSON object + * + * @return JSON representation of the NetworkUser object + */ + json_t ToJson() const; + }; -class NetworkUserManager final -{ -public: - void Load(); + class NetworkUserManager final + { + public: + void Load(); - /** - * @brief NetworkUserManager::Save - * Reads mappings from JSON, updates them in-place and saves JSON. - * - * Useful for retaining custom entries in JSON file. - */ - void Save(); + /** + * @brief NetworkUserManager::Save + * Reads mappings from JSON, updates them in-place and saves JSON. + * + * Useful for retaining custom entries in JSON file. + */ + void Save(); - void UnsetUsersOfGroup(uint8_t groupId); - void RemoveUser(const std::string& hash); + void UnsetUsersOfGroup(uint8_t groupId); + void RemoveUser(const std::string& hash); - const NetworkUser* GetUserByHash(const std::string& hash) const; - const NetworkUser* GetUserByName(const std::string& name) const; - NetworkUser* GetOrAddUser(const std::string& hash); + const NetworkUser* GetUserByHash(const std::string& hash) const; + const NetworkUser* GetUserByName(const std::string& name) const; + NetworkUser* GetOrAddUser(const std::string& hash); -private: - std::unordered_map> _usersByHash; + private: + std::unordered_map> _usersByHash; - static u8string GetStorePath(); -}; + static u8string GetStorePath(); + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/ServerList.cpp b/src/openrct2/network/ServerList.cpp index a96e3a926a..8eb3743759 100644 --- a/src/openrct2/network/ServerList.cpp +++ b/src/openrct2/network/ServerList.cpp @@ -31,412 +31,415 @@ #include #include -using namespace OpenRCT2; - -int32_t ServerListEntry::CompareTo(const ServerListEntry& other) const +namespace OpenRCT2::Network { - const auto& a = *this; - const auto& b = other; - - if (a.Favourite != b.Favourite) + int32_t ServerListEntry::CompareTo(const ServerListEntry& other) const { - return a.Favourite ? -1 : 1; - } + const auto& a = *this; + const auto& b = other; - if (a.Local != b.Local) - { - return a.Local ? -1 : 1; - } - - bool serverACompatible = a.Version == NetworkGetVersion(); - bool serverBCompatible = b.Version == NetworkGetVersion(); - if (serverACompatible != serverBCompatible) - { - return serverACompatible ? -1 : 1; - } - - if (a.RequiresPassword != b.RequiresPassword) - { - return a.RequiresPassword ? 1 : -1; - } - - if (a.Players != b.Players) - { - return a.Players > b.Players ? -1 : 1; - } - - return String::compare(a.Name, b.Name, true); -} - -bool ServerListEntry::IsVersionValid() const noexcept -{ - return Version.empty() || Version == NetworkGetVersion(); -} - -std::optional ServerListEntry::FromJson(json_t& server) -{ - Guard::Assert(server.is_object(), "ServerListEntry::FromJson expects parameter server to be object"); - - const auto port = Json::GetNumber(server["port"]); - const auto name = Json::GetString(server["name"]); - const auto description = Json::GetString(server["description"]); - const auto requiresPassword = Json::GetBoolean(server["requiresPassword"]); - const auto version = Json::GetString(server["version"]); - const auto players = Json::GetNumber(server["players"]); - const auto maxPlayers = Json::GetNumber(server["maxPlayers"]); - std::string ip; - // if server["ip"] or server["ip"]["v4"] are values, this will throw an exception, so check first - if (server["ip"].is_object() && server["ip"]["v4"].is_array()) - { - ip = Json::GetString(server["ip"]["v4"][0]); - } - - if (name.empty() || version.empty()) - { - LOG_VERBOSE("Cowardly refusing to add server without name or version specified."); - - return std::nullopt; - } - - ServerListEntry entry; - - entry.Address = ip + ":" + std::to_string(port); - entry.Name = name; - entry.Description = description; - entry.Version = version; - entry.RequiresPassword = requiresPassword; - entry.Players = players; - entry.MaxPlayers = maxPlayers; - - return entry; -} - -void ServerList::Sort() -{ - _serverEntries.erase( - std::unique( - _serverEntries.begin(), _serverEntries.end(), - [](const ServerListEntry& a, const ServerListEntry& b) { - if (a.Favourite == b.Favourite) - { - return String::iequals(a.Address, b.Address); - } - return false; - }), - _serverEntries.end()); - std::sort(_serverEntries.begin(), _serverEntries.end(), [](const ServerListEntry& a, const ServerListEntry& b) { - return a.CompareTo(b) < 0; - }); -} - -ServerListEntry& ServerList::GetServer(size_t index) -{ - return _serverEntries[index]; -} - -size_t ServerList::GetCount() const -{ - return _serverEntries.size(); -} - -void ServerList::Add(const ServerListEntry& entry) -{ - _serverEntries.push_back(entry); - Sort(); -} - -void ServerList::AddRange(const std::vector& entries) -{ - _serverEntries.insert(_serverEntries.end(), entries.begin(), entries.end()); - Sort(); -} - -void ServerList::AddOrUpdateRange(const std::vector& entries) -{ - for (auto& existsEntry : _serverEntries) - { - auto match = std::find_if( - entries.begin(), entries.end(), [&](const ServerListEntry& entry) { return existsEntry.Address == entry.Address; }); - if (match != entries.end()) + if (a.Favourite != b.Favourite) { - // Keep favourites - auto fav = existsEntry.Favourite; - - existsEntry = *match; - existsEntry.Favourite = fav; + return a.Favourite ? -1 : 1; } + + if (a.Local != b.Local) + { + return a.Local ? -1 : 1; + } + + bool serverACompatible = a.Version == NetworkGetVersion(); + bool serverBCompatible = b.Version == NetworkGetVersion(); + if (serverACompatible != serverBCompatible) + { + return serverACompatible ? -1 : 1; + } + + if (a.RequiresPassword != b.RequiresPassword) + { + return a.RequiresPassword ? 1 : -1; + } + + if (a.Players != b.Players) + { + return a.Players > b.Players ? -1 : 1; + } + + return String::compare(a.Name, b.Name, true); } - std::vector newServers; - std::copy_if(entries.begin(), entries.end(), std::back_inserter(newServers), [this](const ServerListEntry& entry) { - return std::find_if( - _serverEntries.begin(), _serverEntries.end(), - [&](const ServerListEntry& existsEntry) { return existsEntry.Address == entry.Address; }) - == _serverEntries.end(); - }); - - AddRange(newServers); -} - -void ServerList::Clear() noexcept -{ - _serverEntries.clear(); -} - -std::vector ServerList::ReadFavourites() const -{ - LOG_VERBOSE("server_list_read(...)"); - std::vector entries; - try + bool ServerListEntry::IsVersionValid() const noexcept { - auto& env = GetContext()->GetPlatformEnvironment(); - auto path = env.GetFilePath(PathId::networkServers); - if (File::Exists(path)) + return Version.empty() || Version == NetworkGetVersion(); + } + + std::optional ServerListEntry::FromJson(json_t& server) + { + Guard::Assert(server.is_object(), "ServerListEntry::FromJson expects parameter server to be object"); + + const auto port = Json::GetNumber(server["port"]); + const auto name = Json::GetString(server["name"]); + const auto description = Json::GetString(server["description"]); + const auto requiresPassword = Json::GetBoolean(server["requiresPassword"]); + const auto version = Json::GetString(server["version"]); + const auto players = Json::GetNumber(server["players"]); + const auto maxPlayers = Json::GetNumber(server["maxPlayers"]); + std::string ip; + // if server["ip"] or server["ip"]["v4"] are values, this will throw an exception, so check first + if (server["ip"].is_object() && server["ip"]["v4"].is_array()) { - auto fs = FileStream(path, FileMode::open); - auto numEntries = fs.ReadValue(); - for (size_t i = 0; i < numEntries; i++) + ip = Json::GetString(server["ip"]["v4"][0]); + } + + if (name.empty() || version.empty()) + { + LOG_VERBOSE("Cowardly refusing to add server without name or version specified."); + + return std::nullopt; + } + + ServerListEntry entry; + + entry.Address = ip + ":" + std::to_string(port); + entry.Name = name; + entry.Description = description; + entry.Version = version; + entry.RequiresPassword = requiresPassword; + entry.Players = players; + entry.MaxPlayers = maxPlayers; + + return entry; + } + + void ServerList::Sort() + { + _serverEntries.erase( + std::unique( + _serverEntries.begin(), _serverEntries.end(), + [](const ServerListEntry& a, const ServerListEntry& b) { + if (a.Favourite == b.Favourite) + { + return String::iequals(a.Address, b.Address); + } + return false; + }), + _serverEntries.end()); + std::sort(_serverEntries.begin(), _serverEntries.end(), [](const ServerListEntry& a, const ServerListEntry& b) { + return a.CompareTo(b) < 0; + }); + } + + ServerListEntry& ServerList::GetServer(size_t index) + { + return _serverEntries[index]; + } + + size_t ServerList::GetCount() const + { + return _serverEntries.size(); + } + + void ServerList::Add(const ServerListEntry& entry) + { + _serverEntries.push_back(entry); + Sort(); + } + + void ServerList::AddRange(const std::vector& entries) + { + _serverEntries.insert(_serverEntries.end(), entries.begin(), entries.end()); + Sort(); + } + + void ServerList::AddOrUpdateRange(const std::vector& entries) + { + for (auto& existsEntry : _serverEntries) + { + auto match = std::find_if(entries.begin(), entries.end(), [&](const ServerListEntry& entry) { + return existsEntry.Address == entry.Address; + }); + if (match != entries.end()) { - ServerListEntry serverInfo; - serverInfo.Address = fs.ReadString(); - serverInfo.Name = fs.ReadString(); - serverInfo.RequiresPassword = false; - serverInfo.Description = fs.ReadString(); - serverInfo.Version.clear(); - serverInfo.Favourite = true; - serverInfo.Players = 0; - serverInfo.MaxPlayers = 0; - entries.push_back(std::move(serverInfo)); + // Keep favourites + auto fav = existsEntry.Favourite; + + existsEntry = *match; + existsEntry.Favourite = fav; } } + + std::vector newServers; + std::copy_if(entries.begin(), entries.end(), std::back_inserter(newServers), [this](const ServerListEntry& entry) { + return std::find_if( + _serverEntries.begin(), _serverEntries.end(), + [&](const ServerListEntry& existsEntry) { return existsEntry.Address == entry.Address; }) + == _serverEntries.end(); + }); + + AddRange(newServers); } - catch (const std::exception& e) + + void ServerList::Clear() noexcept { - LOG_ERROR("Unable to read server list: %s", e.what()); - entries = std::vector(); + _serverEntries.clear(); } - return entries; -} -void ServerList::ReadAndAddFavourites() -{ - _serverEntries.erase( - std::remove_if( - _serverEntries.begin(), _serverEntries.end(), [](const ServerListEntry& entry) { return entry.Favourite; }), - _serverEntries.end()); - auto entries = ReadFavourites(); - AddRange(entries); -} - -void ServerList::WriteFavourites() const -{ - // Save just favourite servers - std::vector favouriteServers; - std::copy_if( - _serverEntries.begin(), _serverEntries.end(), std::back_inserter(favouriteServers), - [](const ServerListEntry& entry) { return entry.Favourite; }); - WriteFavourites(favouriteServers); -} - -bool ServerList::WriteFavourites(const std::vector& entries) const -{ - LOG_VERBOSE("server_list_write(%d, 0x%p)", entries.size(), entries.data()); - - auto& env = GetContext()->GetPlatformEnvironment(); - auto path = Path::Combine(env.GetDirectoryPath(DirBase::user), u8"servers.cfg"); - - try + std::vector ServerList::ReadFavourites() const { - auto fs = FileStream(path, FileMode::write); - fs.WriteValue(static_cast(entries.size())); - for (const auto& entry : entries) - { - fs.WriteString(entry.Address); - fs.WriteString(entry.Name); - fs.WriteString(entry.Description); - } - return true; - } - catch (const std::exception& e) - { - LOG_ERROR("Unable to write server list: %s", e.what()); - return false; - } -} - -std::future> ServerList::FetchLocalServerListAsync(const INetworkEndpoint& broadcastEndpoint) const -{ - auto broadcastAddress = broadcastEndpoint.GetHostname(); - return std::async(std::launch::async, [broadcastAddress] { - constexpr auto kReceiveDelayInMs = 10; - constexpr auto kReceiveWaitInMs = 2000; - - std::string_view msg = kNetworkLanBroadcastMsg; - auto udpSocket = CreateUdpSocket(); - - LOG_VERBOSE("Broadcasting %zu bytes to the LAN (%s)", msg.size(), broadcastAddress.c_str()); - auto len = udpSocket->SendData(broadcastAddress, kNetworkLanBroadcastPort, msg.data(), msg.size()); - if (len != msg.size()) - { - throw std::runtime_error("Unable to broadcast server query."); - } - + LOG_VERBOSE("server_list_read(...)"); std::vector entries; - for (int i = 0; i < (kReceiveWaitInMs / kReceiveDelayInMs); i++) + try { - try + auto& env = GetContext()->GetPlatformEnvironment(); + auto path = env.GetFilePath(PathId::networkServers); + if (File::Exists(path)) { - // Start with initialised buffer in case we receive a non-terminated string - char buffer[1024]{}; - size_t recievedLen{}; - std::unique_ptr endpoint; - auto p = udpSocket->ReceiveData(buffer, sizeof(buffer) - 1, &recievedLen, &endpoint); - if (p == NetworkReadPacket::Success) + auto fs = FileStream(path, FileMode::open); + auto numEntries = fs.ReadValue(); + for (size_t i = 0; i < numEntries; i++) { - auto sender = endpoint->GetHostname(); - LOG_VERBOSE("Received %zu bytes back from %s", recievedLen, sender.c_str()); - auto jinfo = Json::FromString(std::string_view(buffer)); + ServerListEntry serverInfo; + serverInfo.Address = fs.ReadString(); + serverInfo.Name = fs.ReadString(); + serverInfo.RequiresPassword = false; + serverInfo.Description = fs.ReadString(); + serverInfo.Version.clear(); + serverInfo.Favourite = true; + serverInfo.Players = 0; + serverInfo.MaxPlayers = 0; + entries.push_back(std::move(serverInfo)); + } + } + } + catch (const std::exception& e) + { + LOG_ERROR("Unable to read server list: %s", e.what()); + entries = std::vector(); + } + return entries; + } - if (jinfo.is_object()) + void ServerList::ReadAndAddFavourites() + { + _serverEntries.erase( + std::remove_if( + _serverEntries.begin(), _serverEntries.end(), [](const ServerListEntry& entry) { return entry.Favourite; }), + _serverEntries.end()); + auto entries = ReadFavourites(); + AddRange(entries); + } + + void ServerList::WriteFavourites() const + { + // Save just favourite servers + std::vector favouriteServers; + std::copy_if( + _serverEntries.begin(), _serverEntries.end(), std::back_inserter(favouriteServers), + [](const ServerListEntry& entry) { return entry.Favourite; }); + WriteFavourites(favouriteServers); + } + + bool ServerList::WriteFavourites(const std::vector& entries) const + { + LOG_VERBOSE("server_list_write(%d, 0x%p)", entries.size(), entries.data()); + + auto& env = GetContext()->GetPlatformEnvironment(); + auto path = Path::Combine(env.GetDirectoryPath(DirBase::user), u8"servers.cfg"); + + try + { + auto fs = FileStream(path, FileMode::write); + fs.WriteValue(static_cast(entries.size())); + for (const auto& entry : entries) + { + fs.WriteString(entry.Address); + fs.WriteString(entry.Name); + fs.WriteString(entry.Description); + } + return true; + } + catch (const std::exception& e) + { + LOG_ERROR("Unable to write server list: %s", e.what()); + return false; + } + } + + std::future> ServerList::FetchLocalServerListAsync( + const INetworkEndpoint& broadcastEndpoint) const + { + auto broadcastAddress = broadcastEndpoint.GetHostname(); + return std::async(std::launch::async, [broadcastAddress] { + constexpr auto kReceiveDelayInMs = 10; + constexpr auto kReceiveWaitInMs = 2000; + + std::string_view msg = kNetworkLanBroadcastMsg; + auto udpSocket = CreateUdpSocket(); + + LOG_VERBOSE("Broadcasting %zu bytes to the LAN (%s)", msg.size(), broadcastAddress.c_str()); + auto len = udpSocket->SendData(broadcastAddress, kNetworkLanBroadcastPort, msg.data(), msg.size()); + if (len != msg.size()) + { + throw std::runtime_error("Unable to broadcast server query."); + } + + std::vector entries; + for (int i = 0; i < (kReceiveWaitInMs / kReceiveDelayInMs); i++) + { + try + { + // Start with initialised buffer in case we receive a non-terminated string + char buffer[1024]{}; + size_t recievedLen{}; + std::unique_ptr endpoint; + auto p = udpSocket->ReceiveData(buffer, sizeof(buffer) - 1, &recievedLen, &endpoint); + if (p == NetworkReadPacket::Success) { - jinfo["ip"] = { { "v4", { sender } } }; + auto sender = endpoint->GetHostname(); + LOG_VERBOSE("Received %zu bytes back from %s", recievedLen, sender.c_str()); + auto jinfo = Json::FromString(std::string_view(buffer)); - auto entry = ServerListEntry::FromJson(jinfo); - if (entry.has_value()) + if (jinfo.is_object()) { - (*entry).Local = true; - entries.push_back(std::move(*entry)); + jinfo["ip"] = { { "v4", { sender } } }; + + auto entry = ServerListEntry::FromJson(jinfo); + if (entry.has_value()) + { + (*entry).Local = true; + entries.push_back(std::move(*entry)); + } } } } + catch (const std::exception& e) + { + LOG_WARNING("Error receiving data: %s", e.what()); + } + Platform::Sleep(kReceiveDelayInMs); } - catch (const std::exception& e) + return entries; + }); + } + + std::future> ServerList::FetchLocalServerListAsync() const + { + return std::async(std::launch::async, [&] { + // Get all possible LAN broadcast addresses + auto broadcastEndpoints = GetBroadcastAddresses(); + + // Spin off a fetch for each broadcast address + std::vector>> futures; + for (const auto& broadcastEndpoint : broadcastEndpoints) { - LOG_WARNING("Error receiving data: %s", e.what()); + auto f = FetchLocalServerListAsync(*broadcastEndpoint); + futures.push_back(std::move(f)); } - Platform::Sleep(kReceiveDelayInMs); - } - return entries; - }); -} -std::future> ServerList::FetchLocalServerListAsync() const -{ - return std::async(std::launch::async, [&] { - // Get all possible LAN broadcast addresses - auto broadcastEndpoints = GetBroadcastAddresses(); + // Wait and merge all results + std::vector mergedEntries; + for (auto& f : futures) + { + try + { + auto entries = f.get(); + mergedEntries.insert(mergedEntries.begin(), entries.begin(), entries.end()); + } + catch (...) + { + // Ignore any exceptions from a particular broadcast fetch + } + } + return mergedEntries; + }); + } - // Spin off a fetch for each broadcast address - std::vector>> futures; - for (const auto& broadcastEndpoint : broadcastEndpoints) + std::future> ServerList::FetchOnlineServerListAsync() const + { + #ifdef DISABLE_HTTP + return {}; + #else + + auto p = std::make_shared>>(); + auto f = p->get_future(); + + std::string masterServerUrl = kMasterServerURL; + if (!Config::Get().network.MasterServerUrl.empty()) { - auto f = FetchLocalServerListAsync(*broadcastEndpoint); - futures.push_back(std::move(f)); + masterServerUrl = Config::Get().network.MasterServerUrl; } - // Wait and merge all results - std::vector mergedEntries; - for (auto& f : futures) - { + Http::Request request; + request.url = std::move(masterServerUrl); + request.method = Http::Method::GET; + request.header["Accept"] = "application/json"; + Http::DoAsync(request, [p](Http::Response& response) -> void { + json_t root; try { - auto entries = f.get(); - mergedEntries.insert(mergedEntries.begin(), entries.begin(), entries.end()); + if (response.status != Http::Status::Ok) + { + throw MasterServerException(STR_SERVER_LIST_NO_CONNECTION); + } + + root = Json::FromString(response.body); + if (root.is_object()) + { + auto jsonStatus = root["status"]; + if (!jsonStatus.is_number_integer()) + { + throw MasterServerException(STR_SERVER_LIST_INVALID_RESPONSE_JSON_NUMBER); + } + + auto status = Json::GetNumber(jsonStatus); + if (status != 200) + { + throw MasterServerException(STR_SERVER_LIST_MASTER_SERVER_FAILED); + } + + auto jServers = root["servers"]; + if (!jServers.is_array()) + { + throw MasterServerException(STR_SERVER_LIST_INVALID_RESPONSE_JSON_ARRAY); + } + + std::vector entries; + for (auto& jServer : jServers) + { + if (jServer.is_object()) + { + auto entry = ServerListEntry::FromJson(jServer); + if (entry.has_value()) + { + entries.push_back(std::move(*entry)); + } + } + } + + p->set_value(entries); + } } catch (...) { - // Ignore any exceptions from a particular broadcast fetch + p->set_exception(std::current_exception()); } - } - return mergedEntries; - }); -} - -std::future> ServerList::FetchOnlineServerListAsync() const -{ - #ifdef DISABLE_HTTP - return {}; - #else - - auto p = std::make_shared>>(); - auto f = p->get_future(); - - std::string masterServerUrl = kMasterServerURL; - if (!Config::Get().network.MasterServerUrl.empty()) - { - masterServerUrl = Config::Get().network.MasterServerUrl; + }); + return f; + #endif } - Http::Request request; - request.url = std::move(masterServerUrl); - request.method = Http::Method::GET; - request.header["Accept"] = "application/json"; - Http::DoAsync(request, [p](Http::Response& response) -> void { - json_t root; - try - { - if (response.status != Http::Status::Ok) - { - throw MasterServerException(STR_SERVER_LIST_NO_CONNECTION); - } + uint32_t ServerList::GetTotalPlayerCount() const + { + return std::accumulate(_serverEntries.begin(), _serverEntries.end(), 0, [](uint32_t acc, const ServerListEntry& entry) { + return acc + entry.Players; + }); + } - root = Json::FromString(response.body); - if (root.is_object()) - { - auto jsonStatus = root["status"]; - if (!jsonStatus.is_number_integer()) - { - throw MasterServerException(STR_SERVER_LIST_INVALID_RESPONSE_JSON_NUMBER); - } - - auto status = Json::GetNumber(jsonStatus); - if (status != 200) - { - throw MasterServerException(STR_SERVER_LIST_MASTER_SERVER_FAILED); - } - - auto jServers = root["servers"]; - if (!jServers.is_array()) - { - throw MasterServerException(STR_SERVER_LIST_INVALID_RESPONSE_JSON_ARRAY); - } - - std::vector entries; - for (auto& jServer : jServers) - { - if (jServer.is_object()) - { - auto entry = ServerListEntry::FromJson(jServer); - if (entry.has_value()) - { - entries.push_back(std::move(*entry)); - } - } - } - - p->set_value(entries); - } - } - catch (...) - { - p->set_exception(std::current_exception()); - } - }); - return f; - #endif -} - -uint32_t ServerList::GetTotalPlayerCount() const -{ - return std::accumulate(_serverEntries.begin(), _serverEntries.end(), 0, [](uint32_t acc, const ServerListEntry& entry) { - return acc + entry.Players; - }); -} - -const char* MasterServerException::what() const noexcept -{ - static std::string localisedStatusText = LanguageGetString(StatusText); - return localisedStatusText.c_str(); -} + const char* MasterServerException::what() const noexcept + { + static std::string localisedStatusText = LanguageGetString(StatusText); + return localisedStatusText.c_str(); + } +} // namespace OpenRCT2::Network #endif diff --git a/src/openrct2/network/ServerList.h b/src/openrct2/network/ServerList.h index 9368f313bd..c57450b8d3 100644 --- a/src/openrct2/network/ServerList.h +++ b/src/openrct2/network/ServerList.h @@ -18,68 +18,71 @@ #include #include -struct INetworkEndpoint; - -struct ServerListEntry +namespace OpenRCT2::Network { - std::string Address; - std::string Name; - std::string Description; - std::string Version; - bool RequiresPassword{}; - bool Favourite{}; - uint8_t Players{}; - uint8_t MaxPlayers{}; - bool Local{}; + struct INetworkEndpoint; - int32_t CompareTo(const ServerListEntry& other) const; - bool IsVersionValid() const noexcept; - - /** - * Creates a ServerListEntry object from a JSON object - * - * @param json JSON data source - must be object type - * @return A NetworkGroup object - * @note json is deliberately left non-const: json_t behaviour changes when const - */ - static std::optional FromJson(json_t& server); -}; - -class ServerList -{ -private: - std::vector _serverEntries; - - void Sort(); - std::vector ReadFavourites() const; - bool WriteFavourites(const std::vector& entries) const; - std::future> FetchLocalServerListAsync(const INetworkEndpoint& broadcastEndpoint) const; - -public: - ServerListEntry& GetServer(size_t index); - size_t GetCount() const; - void Add(const ServerListEntry& entry); - void AddRange(const std::vector& entries); - void AddOrUpdateRange(const std::vector& entries); - void Clear() noexcept; - - void ReadAndAddFavourites(); - void WriteFavourites() const; - - std::future> FetchLocalServerListAsync() const; - std::future> FetchOnlineServerListAsync() const; - uint32_t GetTotalPlayerCount() const; -}; - -class MasterServerException : public std::exception -{ -public: - StringId StatusText; - - MasterServerException(StringId statusText) - : StatusText(statusText) + struct ServerListEntry { - } + std::string Address; + std::string Name; + std::string Description; + std::string Version; + bool RequiresPassword{}; + bool Favourite{}; + uint8_t Players{}; + uint8_t MaxPlayers{}; + bool Local{}; - const char* what() const noexcept override; -}; + int32_t CompareTo(const ServerListEntry& other) const; + bool IsVersionValid() const noexcept; + + /** + * Creates a ServerListEntry object from a JSON object + * + * @param json JSON data source - must be object type + * @return A NetworkGroup object + * @note json is deliberately left non-const: json_t behaviour changes when const + */ + static std::optional FromJson(json_t& server); + }; + + class ServerList + { + private: + std::vector _serverEntries; + + void Sort(); + std::vector ReadFavourites() const; + bool WriteFavourites(const std::vector& entries) const; + std::future> FetchLocalServerListAsync(const INetworkEndpoint& broadcastEndpoint) const; + + public: + ServerListEntry& GetServer(size_t index); + size_t GetCount() const; + void Add(const ServerListEntry& entry); + void AddRange(const std::vector& entries); + void AddOrUpdateRange(const std::vector& entries); + void Clear() noexcept; + + void ReadAndAddFavourites(); + void WriteFavourites() const; + + std::future> FetchLocalServerListAsync() const; + std::future> FetchOnlineServerListAsync() const; + uint32_t GetTotalPlayerCount() const; + }; + + class MasterServerException : public std::exception + { + public: + StringId StatusText; + + MasterServerException(StringId statusText) + : StatusText(statusText) + { + } + + const char* what() const noexcept override; + }; +} // namespace OpenRCT2::Network diff --git a/src/openrct2/network/Socket.cpp b/src/openrct2/network/Socket.cpp index 063b0ffe1a..6fc90cc7aa 100644 --- a/src/openrct2/network/Socket.cpp +++ b/src/openrct2/network/Socket.cpp @@ -73,354 +73,243 @@ #include "Socket.h" -constexpr auto kConnectTimeout = std::chrono::milliseconds(3000); +namespace OpenRCT2::Network +{ + constexpr auto kConnectTimeout = std::chrono::milliseconds(3000); // RAII WSA initialisation needed for Windows #ifdef _WIN32 -class WSA -{ -private: - bool _isInitialised{}; - -public: - bool IsInitialised() const noexcept + class WSA { - return _isInitialised; - } + private: + bool _isInitialised{}; - bool Initialise() - { - if (!_isInitialised) + public: + bool IsInitialised() const noexcept { - LOG_VERBOSE("WSAStartup()"); - WSADATA wsa_data; - if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) + return _isInitialised; + } + + bool Initialise() + { + if (!_isInitialised) { - LOG_ERROR("Unable to initialise winsock."); + LOG_VERBOSE("WSAStartup()"); + WSADATA wsa_data; + if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) + { + LOG_ERROR("Unable to initialise winsock."); + return false; + } + _isInitialised = true; + } + return true; + } + + ~WSA() noexcept + { + if (_isInitialised) + { + LOG_VERBOSE("WSACleanup()"); + WSACleanup(); + _isInitialised = false; + } + } + }; + + static bool InitialiseWSA() + { + static WSA wsa; + return wsa.Initialise(); + } + #else + static bool InitialiseWSA() + { + return true; + } + #endif + + class SocketException : public std::runtime_error + { + public: + explicit SocketException(const std::string& message) + : std::runtime_error(message) + { + } + }; + + class NetworkEndpoint final : public INetworkEndpoint + { + private: + sockaddr _address{}; + socklen_t _addressLen{}; + + public: + NetworkEndpoint() noexcept = default; + + NetworkEndpoint(const sockaddr* address, socklen_t addressLen) + { + std::memcpy(&_address, address, addressLen); + _addressLen = addressLen; + } + + constexpr const sockaddr& GetAddress() const noexcept + { + return _address; + } + + constexpr socklen_t GetAddressLen() const noexcept + { + return _addressLen; + } + + int32_t GetPort() const + { + if (_address.sa_family == AF_INET) + { + return reinterpret_cast(&_address)->sin_port; + } + + return reinterpret_cast(&_address)->sin6_port; + } + + std::string GetHostname() const override + { + char hostname[256]{}; + int res = getnameinfo(&_address, _addressLen, hostname, sizeof(hostname), nullptr, 0, NI_NUMERICHOST); + if (res == 0) + { + return hostname; + } + return {}; + } + }; + + class Socket + { + protected: + static bool ResolveAddress(const std::string& address, uint16_t port, sockaddr_storage* ss, socklen_t* ss_len) + { + return ResolveAddress(AF_UNSPEC, address, port, ss, ss_len); + } + + static bool ResolveAddressIPv4(const std::string& address, uint16_t port, sockaddr_storage* ss, socklen_t* ss_len) + { + return ResolveAddress(AF_INET, address, port, ss, ss_len); + } + + static bool SetNonBlocking(SOCKET socket, bool on) + { + #ifdef _WIN32 + u_long nonBlocking = on; + return ioctlsocket(socket, FIONBIO, &nonBlocking) == 0; + #else + int32_t flags = fcntl(socket, F_GETFL, 0); + return fcntl(socket, F_SETFL, on ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK)) == 0; + #endif + } + + static bool SetOption(SOCKET socket, int32_t a, int32_t b, bool value) + { + int32_t ivalue = value ? 1 : 0; + return setsockopt(socket, a, b, reinterpret_cast(&ivalue), sizeof(ivalue)) == 0; + } + + private: + static bool ResolveAddress( + int32_t family, const std::string& address, uint16_t port, sockaddr_storage* ss, socklen_t* ss_len) + { + std::string serviceName = std::to_string(port); + + addrinfo hints = {}; + hints.ai_family = family; + if (address.empty()) + { + hints.ai_flags = AI_PASSIVE; + } + + addrinfo* result = nullptr; + int errorcode = getaddrinfo(address.empty() ? nullptr : address.c_str(), serviceName.c_str(), &hints, &result); + if (errorcode != 0) + { + LOG_ERROR("Resolving address failed: Code %d.", errorcode); + LOG_ERROR("Resolution error message: %s.", gai_strerror(errorcode)); return false; } - _isInitialised = true; - } - return true; - } - ~WSA() noexcept - { - if (_isInitialised) - { - LOG_VERBOSE("WSACleanup()"); - WSACleanup(); - _isInitialised = false; - } - } -}; - -static bool InitialiseWSA() -{ - static WSA wsa; - return wsa.Initialise(); -} - #else -static bool InitialiseWSA() -{ - return true; -} - #endif - -class SocketException : public std::runtime_error -{ -public: - explicit SocketException(const std::string& message) - : std::runtime_error(message) - { - } -}; - -class NetworkEndpoint final : public INetworkEndpoint -{ -private: - sockaddr _address{}; - socklen_t _addressLen{}; - -public: - NetworkEndpoint() noexcept = default; - - NetworkEndpoint(const sockaddr* address, socklen_t addressLen) - { - std::memcpy(&_address, address, addressLen); - _addressLen = addressLen; - } - - constexpr const sockaddr& GetAddress() const noexcept - { - return _address; - } - - constexpr socklen_t GetAddressLen() const noexcept - { - return _addressLen; - } - - int32_t GetPort() const - { - if (_address.sa_family == AF_INET) - { - return reinterpret_cast(&_address)->sin_port; - } - - return reinterpret_cast(&_address)->sin6_port; - } - - std::string GetHostname() const override - { - char hostname[256]{}; - int res = getnameinfo(&_address, _addressLen, hostname, sizeof(hostname), nullptr, 0, NI_NUMERICHOST); - if (res == 0) - { - return hostname; - } - return {}; - } -}; - -class Socket -{ -protected: - static bool ResolveAddress(const std::string& address, uint16_t port, sockaddr_storage* ss, socklen_t* ss_len) - { - return ResolveAddress(AF_UNSPEC, address, port, ss, ss_len); - } - - static bool ResolveAddressIPv4(const std::string& address, uint16_t port, sockaddr_storage* ss, socklen_t* ss_len) - { - return ResolveAddress(AF_INET, address, port, ss, ss_len); - } - - static bool SetNonBlocking(SOCKET socket, bool on) - { - #ifdef _WIN32 - u_long nonBlocking = on; - return ioctlsocket(socket, FIONBIO, &nonBlocking) == 0; - #else - int32_t flags = fcntl(socket, F_GETFL, 0); - return fcntl(socket, F_SETFL, on ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK)) == 0; - #endif - } - - static bool SetOption(SOCKET socket, int32_t a, int32_t b, bool value) - { - int32_t ivalue = value ? 1 : 0; - return setsockopt(socket, a, b, reinterpret_cast(&ivalue), sizeof(ivalue)) == 0; - } - -private: - static bool ResolveAddress( - int32_t family, const std::string& address, uint16_t port, sockaddr_storage* ss, socklen_t* ss_len) - { - std::string serviceName = std::to_string(port); - - addrinfo hints = {}; - hints.ai_family = family; - if (address.empty()) - { - hints.ai_flags = AI_PASSIVE; - } - - addrinfo* result = nullptr; - int errorcode = getaddrinfo(address.empty() ? nullptr : address.c_str(), serviceName.c_str(), &hints, &result); - if (errorcode != 0) - { - LOG_ERROR("Resolving address failed: Code %d.", errorcode); - LOG_ERROR("Resolution error message: %s.", gai_strerror(errorcode)); - return false; - } - - if (result == nullptr) - { - return false; - } - - std::memcpy(ss, result->ai_addr, result->ai_addrlen); - *ss_len = static_cast(result->ai_addrlen); - freeaddrinfo(result); - return true; - } -}; - -class TcpSocket final : public ITcpSocket, protected Socket -{ -private: - std::atomic _status{ SocketStatus::Closed }; - uint16_t _listeningPort = 0; - SOCKET _socket = INVALID_SOCKET; - - std::string _ipAddress; - std::string _hostName; - std::future _connectFuture; - std::string _error; - -public: - TcpSocket() noexcept = default; - - explicit TcpSocket(SOCKET socket, std::string hostName, std::string ipAddress) noexcept - : _status(SocketStatus::Connected) - , _socket(socket) - , _ipAddress(std::move(ipAddress)) - , _hostName(std::move(hostName)) - { - } - - ~TcpSocket() override - { - if (_connectFuture.valid()) - { - _connectFuture.wait(); - } - CloseSocket(); - } - - SocketStatus GetStatus() const override - { - return _status; - } - - const char* GetError() const override - { - return _error.empty() ? nullptr : _error.c_str(); - } - - void SetNoDelay(bool noDelay) override - { - if (_socket != INVALID_SOCKET) - { - SetOption(_socket, IPPROTO_TCP, TCP_NODELAY, noDelay); - } - } - - void Listen(uint16_t port) override - { - Listen("", port); - } - - void Listen(const std::string& address, uint16_t port) override - { - if (_status != SocketStatus::Closed) - { - throw std::runtime_error("Socket not closed."); - } - - sockaddr_storage ss{}; - socklen_t ss_len; - if (!ResolveAddress(address, port, &ss, &ss_len)) - { - throw SocketException("Unable to resolve address."); - } - - // Create the listening socket - _socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP); - if (_socket == INVALID_SOCKET) - { - throw SocketException("Unable to create socket."); - } - - // Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections - if (!SetOption(_socket, IPPROTO_IPV6, IPV6_V6ONLY, false)) - { - LOG_VERBOSE("setsockopt(socket, IPV6_V6ONLY) failed: %d", LAST_SOCKET_ERROR()); - } - - if (!SetOption(_socket, SOL_SOCKET, SO_REUSEADDR, true)) - { - LOG_VERBOSE("setsockopt(socket, SO_REUSEADDR) failed: %d", LAST_SOCKET_ERROR()); - } - - try - { - // Bind to address:port and listen - if (bind(_socket, reinterpret_cast(&ss), ss_len) != 0) + if (result == nullptr) { - std::string addressOrStar = address.empty() ? "*" : address.c_str(); - throw SocketException("Unable to bind to address " + addressOrStar + ":" + std::to_string(port)); - } - if (listen(_socket, SOMAXCONN) != 0) - { - throw SocketException("Unable to listen on socket."); + return false; } - if (!SetNonBlocking(_socket, true)) - { - throw SocketException("Failed to set non-blocking mode."); - } + std::memcpy(ss, result->ai_addr, result->ai_addrlen); + *ss_len = static_cast(result->ai_addrlen); + freeaddrinfo(result); + return true; } - catch (const std::exception&) + }; + + class TcpSocket final : public ITcpSocket, protected Socket + { + private: + std::atomic _status{ SocketStatus::Closed }; + uint16_t _listeningPort = 0; + SOCKET _socket = INVALID_SOCKET; + + std::string _ipAddress; + std::string _hostName; + std::future _connectFuture; + std::string _error; + + public: + TcpSocket() noexcept = default; + + explicit TcpSocket(SOCKET socket, std::string hostName, std::string ipAddress) noexcept + : _status(SocketStatus::Connected) + , _socket(socket) + , _ipAddress(std::move(ipAddress)) + , _hostName(std::move(hostName)) { + } + + ~TcpSocket() override + { + if (_connectFuture.valid()) + { + _connectFuture.wait(); + } CloseSocket(); - throw; } - _listeningPort = port; - _status = SocketStatus::Listening; - } - - std::unique_ptr Accept() override - { - if (_status != SocketStatus::Listening) + SocketStatus GetStatus() const override { - throw std::runtime_error("Socket not listening."); + return _status; } - struct sockaddr_storage client_addr{}; - socklen_t client_len = sizeof(struct sockaddr_storage); - std::unique_ptr tcpSocket; - SOCKET socket = accept(_socket, reinterpret_cast(&client_addr), &client_len); - if (socket == INVALID_SOCKET) + const char* GetError() const override { - if (LAST_SOCKET_ERROR() != EWOULDBLOCK) + return _error.empty() ? nullptr : _error.c_str(); + } + + void SetNoDelay(bool noDelay) override + { + if (_socket != INVALID_SOCKET) { - LOG_ERROR("Failed to accept client."); + SetOption(_socket, IPPROTO_TCP, TCP_NODELAY, noDelay); } } - else + + void Listen(uint16_t port) override { - if (!SetNonBlocking(socket, true)) - { - closesocket(socket); - LOG_ERROR("Failed to set non-blocking mode."); - } - else - { - auto ipAddress = GetIpAddressFromSocket(reinterpret_cast(&client_addr)); - - char hostName[NI_MAXHOST]; - int32_t rc = getnameinfo( - reinterpret_cast(&client_addr), client_len, hostName, sizeof(hostName), nullptr, 0, - NI_NUMERICHOST | NI_NUMERICSERV); - SetNoDelay(true); - - if (rc == 0) - { - tcpSocket = std::make_unique(socket, hostName, ipAddress); - } - else - { - tcpSocket = std::make_unique(socket, "", ipAddress); - } - } - } - return tcpSocket; - } - - void Connect(const std::string& address, uint16_t port) override - { - if (_status != SocketStatus::Closed && _status != SocketStatus::Waiting) - { - throw std::runtime_error("Socket not closed."); + Listen("", port); } - try + void Listen(const std::string& address, uint16_t port) override { - // Resolve address - _status = SocketStatus::Resolving; + if (_status != SocketStatus::Closed) + { + throw std::runtime_error("Socket not closed."); + } sockaddr_storage ss{}; socklen_t ss_len; @@ -429,550 +318,666 @@ public: throw SocketException("Unable to resolve address."); } - _status = SocketStatus::Connecting; + // Create the listening socket _socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP); if (_socket == INVALID_SOCKET) { throw SocketException("Unable to create socket."); } - SetNoDelay(true); - if (!SetNonBlocking(_socket, true)) + // Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections + if (!SetOption(_socket, IPPROTO_IPV6, IPV6_V6ONLY, false)) + { + LOG_VERBOSE("setsockopt(socket, IPV6_V6ONLY) failed: %d", LAST_SOCKET_ERROR()); + } + + if (!SetOption(_socket, SOL_SOCKET, SO_REUSEADDR, true)) + { + LOG_VERBOSE("setsockopt(socket, SO_REUSEADDR) failed: %d", LAST_SOCKET_ERROR()); + } + + try + { + // Bind to address:port and listen + if (bind(_socket, reinterpret_cast(&ss), ss_len) != 0) + { + std::string addressOrStar = address.empty() ? "*" : address.c_str(); + throw SocketException("Unable to bind to address " + addressOrStar + ":" + std::to_string(port)); + } + if (listen(_socket, SOMAXCONN) != 0) + { + throw SocketException("Unable to listen on socket."); + } + + if (!SetNonBlocking(_socket, true)) + { + throw SocketException("Failed to set non-blocking mode."); + } + } + catch (const std::exception&) + { + CloseSocket(); + throw; + } + + _listeningPort = port; + _status = SocketStatus::Listening; + } + + std::unique_ptr Accept() override + { + if (_status != SocketStatus::Listening) + { + throw std::runtime_error("Socket not listening."); + } + struct sockaddr_storage client_addr{}; + socklen_t client_len = sizeof(struct sockaddr_storage); + + std::unique_ptr tcpSocket; + SOCKET socket = accept(_socket, reinterpret_cast(&client_addr), &client_len); + if (socket == INVALID_SOCKET) + { + if (LAST_SOCKET_ERROR() != EWOULDBLOCK) + { + LOG_ERROR("Failed to accept client."); + } + } + else + { + if (!SetNonBlocking(socket, true)) + { + closesocket(socket); + LOG_ERROR("Failed to set non-blocking mode."); + } + else + { + auto ipAddress = GetIpAddressFromSocket(reinterpret_cast(&client_addr)); + + char hostName[NI_MAXHOST]; + int32_t rc = getnameinfo( + reinterpret_cast(&client_addr), client_len, hostName, sizeof(hostName), nullptr, 0, + NI_NUMERICHOST | NI_NUMERICSERV); + SetNoDelay(true); + + if (rc == 0) + { + tcpSocket = std::make_unique(socket, hostName, ipAddress); + } + else + { + tcpSocket = std::make_unique(socket, "", ipAddress); + } + } + } + return tcpSocket; + } + + void Connect(const std::string& address, uint16_t port) override + { + if (_status != SocketStatus::Closed && _status != SocketStatus::Waiting) + { + throw std::runtime_error("Socket not closed."); + } + + try + { + // Resolve address + _status = SocketStatus::Resolving; + + sockaddr_storage ss{}; + socklen_t ss_len; + if (!ResolveAddress(address, port, &ss, &ss_len)) + { + throw SocketException("Unable to resolve address."); + } + + _status = SocketStatus::Connecting; + _socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP); + if (_socket == INVALID_SOCKET) + { + throw SocketException("Unable to create socket."); + } + + SetNoDelay(true); + if (!SetNonBlocking(_socket, true)) + { + throw SocketException("Failed to set non-blocking mode."); + } + + // Connect + int32_t connectResult = connect(_socket, reinterpret_cast(&ss), ss_len); + if (connectResult != SOCKET_ERROR || (LAST_SOCKET_ERROR() != EINPROGRESS && LAST_SOCKET_ERROR() != EWOULDBLOCK)) + { + throw SocketException("Failed to connect."); + } + + auto connectStartTime = std::chrono::system_clock::now(); + + int32_t error = 0; + socklen_t len = sizeof(error); + if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) + { + throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR())); + } + if (error != 0) + { + throw SocketException("Connection failed: " + std::to_string(error)); + } + + do + { + // Sleep for a bit + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + fd_set writeFD; + FD_ZERO(&writeFD); + #pragma warning(push) + #pragma warning(disable : 4548) // expression before comma has no effect; expected expression with side-effect + FD_SET(_socket, &writeFD); + #pragma warning(pop) + timeval timeout{}; + timeout.tv_sec = 0; + timeout.tv_usec = 0; + if (select(static_cast(_socket + 1), nullptr, &writeFD, nullptr, &timeout) > 0) + { + error = 0; + len = sizeof(error); + if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) + { + throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR())); + } + if (error == 0) + { + _status = SocketStatus::Connected; + return; + } + } + } while ((std::chrono::system_clock::now() - connectStartTime) < kConnectTimeout); + + // Connection request timed out + throw SocketException("Connection timed out."); + } + catch (const std::exception&) + { + CloseSocket(); + throw; + } + } + + void ConnectAsync(const std::string& address, uint16_t port) override + { + if (_status != SocketStatus::Closed) + { + throw std::runtime_error("Socket not closed."); + } + + // When connect is called, the status is set to resolving, but we want to make sure + // the status is changed before this async method exits. Otherwise, the consumer + // might think the status has closed before it started to connect. + _status = SocketStatus::Waiting; + + auto saddress = std::string(address); + std::promise barrier; + _connectFuture = barrier.get_future(); + auto thread = std::thread( + [this, saddress, port](std::promise barrier2) -> void { + try + { + Connect(saddress.c_str(), port); + } + catch (const std::exception& ex) + { + _error = std::string(ex.what()); + } + barrier2.set_value(); + }, + std::move(barrier)); + thread.detach(); + } + + void Finish() override + { + if (_status == SocketStatus::Connected) + { + shutdown(_socket, SHUT_WR); + } + } + + void Disconnect() override + { + if (_status == SocketStatus::Connected) + { + shutdown(_socket, SHUT_RDWR); + } + _status = SocketStatus::Closed; + } + + size_t SendData(const void* buffer, size_t size) override + { + if (_status != SocketStatus::Connected) + { + throw std::runtime_error("Socket not connected."); + } + + size_t totalSent = 0; + do + { + const char* bufferStart = static_cast(buffer) + totalSent; + size_t remainingSize = size - totalSent; + int32_t sentBytes = send(_socket, bufferStart, static_cast(remainingSize), FLAG_NO_PIPE); + if (sentBytes == SOCKET_ERROR) + { + return totalSent; + } + totalSent += sentBytes; + } while (totalSent < size); + return totalSent; + } + + NetworkReadPacket ReceiveData(void* buffer, size_t size, size_t* sizeReceived) override + { + if (_status != SocketStatus::Connected) + { + throw std::runtime_error("Socket not connected."); + } + + int32_t readBytes = recv(_socket, static_cast(buffer), static_cast(size), 0); + if (readBytes == 0) + { + *sizeReceived = 0; + return NetworkReadPacket::Disconnected; + } + + if (readBytes == SOCKET_ERROR) + { + *sizeReceived = 0; + #ifndef _WIN32 + // Removing the check for EAGAIN and instead relying on the values being the same allows turning on of + // -Wlogical-op warning. + // This is not true on Windows, see: + // * https://msdn.microsoft.com/en-us/library/windows/desktop/ms737828(v=vs.85).aspx + // * https://msdn.microsoft.com/en-us/library/windows/desktop/ms741580(v=vs.85).aspx + // * https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668(v=vs.85).aspx + static_assert( + EWOULDBLOCK == EAGAIN, + "Portability note: your system has different values for EWOULDBLOCK " + "and EAGAIN, please extend the condition below"); + #endif // _WIN32 + if (LAST_SOCKET_ERROR() != EWOULDBLOCK) + { + return NetworkReadPacket::Disconnected; + } + + return NetworkReadPacket::NoData; + } + + *sizeReceived = readBytes; + return NetworkReadPacket::Success; + } + + void Close() override + { + if (_connectFuture.valid()) + { + _connectFuture.wait(); + } + CloseSocket(); + } + + const char* GetHostName() const override + { + return _hostName.empty() ? nullptr : _hostName.c_str(); + } + + std::string GetIpAddress() const override + { + return _ipAddress; + } + + private: + void CloseSocket() + { + if (_socket != INVALID_SOCKET) + { + closesocket(_socket); + _socket = INVALID_SOCKET; + } + _status = SocketStatus::Closed; + } + + std::string GetIpAddressFromSocket(const sockaddr_in* addr) const + { + std::string result; + if (addr->sin_family == AF_INET) + { + char str[INET_ADDRSTRLEN]{}; + inet_ntop(AF_INET, &addr->sin_addr, str, sizeof(str)); + result = str; + } + else if (addr->sin_family == AF_INET6) + { + auto addrv6 = reinterpret_cast(&addr); + char str[INET6_ADDRSTRLEN]{}; + inet_ntop(AF_INET6, &addrv6->sin6_addr, str, sizeof(str)); + result = str; + } + return result; + } + }; + + class UdpSocket final : public IUdpSocket, protected Socket + { + private: + SocketStatus _status = SocketStatus::Closed; + uint16_t _listeningPort = 0; + SOCKET _socket = INVALID_SOCKET; + NetworkEndpoint _endpoint; + + std::string _hostName; + std::string _error; + + public: + UdpSocket() noexcept = default; + + ~UdpSocket() override + { + CloseSocket(); + } + + SocketStatus GetStatus() const override + { + return _status; + } + + const char* GetError() const override + { + return _error.empty() ? nullptr : _error.c_str(); + } + + void Listen(uint16_t port) override + { + Listen("", port); + } + + void Listen(const std::string& address, uint16_t port) override + { + if (_status != SocketStatus::Closed) + { + throw std::runtime_error("Socket not closed."); + } + + sockaddr_storage ss{}; + socklen_t ss_len; + if (!ResolveAddressIPv4(address, port, &ss, &ss_len)) + { + throw SocketException("Unable to resolve address."); + } + + // Create the listening socket + _socket = CreateSocket(); + try + { + // Bind to address:port and listen + if (bind(_socket, reinterpret_cast(&ss), ss_len) != 0) + { + throw SocketException("Unable to bind to socket."); + } + } + catch (const std::exception&) + { + CloseSocket(); + throw; + } + + _listeningPort = port; + _status = SocketStatus::Listening; + } + + size_t SendData(const std::string& address, uint16_t port, const void* buffer, size_t size) override + { + sockaddr_storage ss{}; + socklen_t ss_len; + if (!ResolveAddressIPv4(address, port, &ss, &ss_len)) + { + throw SocketException("Unable to resolve address."); + } + NetworkEndpoint endpoint(reinterpret_cast(&ss), ss_len); + return SendData(endpoint, buffer, size); + } + + size_t SendData(const INetworkEndpoint& destination, const void* buffer, size_t size) override + { + if (_socket == INVALID_SOCKET) + { + _socket = CreateSocket(); + } + + const auto& dest = dynamic_cast(&destination); + if (dest == nullptr) + { + throw std::invalid_argument("destination is not compatible."); + } + auto ss = &dest->GetAddress(); + auto ss_len = dest->GetAddressLen(); + + if (_status != SocketStatus::Listening) + { + _endpoint = *dest; + } + + size_t totalSent = 0; + do + { + const char* bufferStart = static_cast(buffer) + totalSent; + size_t remainingSize = size - totalSent; + int32_t sentBytes = sendto( + _socket, bufferStart, static_cast(remainingSize), FLAG_NO_PIPE, static_cast(ss), + ss_len); + if (sentBytes == SOCKET_ERROR) + { + return totalSent; + } + totalSent += sentBytes; + } while (totalSent < size); + return totalSent; + } + + NetworkReadPacket ReceiveData( + void* buffer, size_t size, size_t* sizeReceived, std::unique_ptr* sender) override + { + sockaddr_in senderAddr{}; + socklen_t senderAddrLen = sizeof(sockaddr_in); + if (_status != SocketStatus::Listening) + { + senderAddrLen = _endpoint.GetAddressLen(); + std::memcpy(&senderAddr, &_endpoint.GetAddress(), senderAddrLen); + } + auto readBytes = recvfrom( + _socket, static_cast(buffer), static_cast(size), 0, reinterpret_cast(&senderAddr), + &senderAddrLen); + if (readBytes <= 0) + { + *sizeReceived = 0; + return NetworkReadPacket::NoData; + } + + *sizeReceived = readBytes; + if (sender != nullptr) + { + *sender = std::make_unique(reinterpret_cast(&senderAddr), senderAddrLen); + } + return NetworkReadPacket::Success; + } + + void Close() override + { + CloseSocket(); + } + + const char* GetHostName() const override + { + return _hostName.empty() ? nullptr : _hostName.c_str(); + } + + private: + SOCKET CreateSocket() const + { + auto sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) + { + throw SocketException("Unable to create socket."); + } + + // Enable send and receiving of broadcast messages + if (!SetOption(sock, SOL_SOCKET, SO_BROADCAST, true)) + { + LOG_VERBOSE("setsockopt(socket, SO_BROADCAST) failed: %d", LAST_SOCKET_ERROR()); + } + + // Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections + if (!SetOption(sock, IPPROTO_IPV6, IPV6_V6ONLY, false)) + { + LOG_VERBOSE("setsockopt(socket, IPV6_V6ONLY) failed: %d", LAST_SOCKET_ERROR()); + } + + if (!SetOption(sock, SOL_SOCKET, SO_REUSEADDR, true)) + { + LOG_VERBOSE("setsockopt(socket, SO_REUSEADDR) failed: %d", LAST_SOCKET_ERROR()); + } + + if (!SetNonBlocking(sock, true)) { throw SocketException("Failed to set non-blocking mode."); } - // Connect - int32_t connectResult = connect(_socket, reinterpret_cast(&ss), ss_len); - if (connectResult != SOCKET_ERROR || (LAST_SOCKET_ERROR() != EINPROGRESS && LAST_SOCKET_ERROR() != EWOULDBLOCK)) + return sock; + } + + void CloseSocket() + { + if (_socket != INVALID_SOCKET) { - throw SocketException("Failed to connect."); + closesocket(_socket); + _socket = INVALID_SOCKET; } - - auto connectStartTime = std::chrono::system_clock::now(); - - int32_t error = 0; - socklen_t len = sizeof(error); - if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) - { - throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR())); - } - if (error != 0) - { - throw SocketException("Connection failed: " + std::to_string(error)); - } - - do - { - // Sleep for a bit - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - fd_set writeFD; - FD_ZERO(&writeFD); - #pragma warning(push) - #pragma warning(disable : 4548) // expression before comma has no effect; expected expression with side-effect - FD_SET(_socket, &writeFD); - #pragma warning(pop) - timeval timeout{}; - timeout.tv_sec = 0; - timeout.tv_usec = 0; - if (select(static_cast(_socket + 1), nullptr, &writeFD, nullptr, &timeout) > 0) - { - error = 0; - len = sizeof(error); - if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) - { - throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR())); - } - if (error == 0) - { - _status = SocketStatus::Connected; - return; - } - } - } while ((std::chrono::system_clock::now() - connectStartTime) < kConnectTimeout); - - // Connection request timed out - throw SocketException("Connection timed out."); + _status = SocketStatus::Closed; } - catch (const std::exception&) - { - CloseSocket(); - throw; - } - } + }; - void ConnectAsync(const std::string& address, uint16_t port) override + std::unique_ptr CreateTcpSocket() { - if (_status != SocketStatus::Closed) - { - throw std::runtime_error("Socket not closed."); - } - - // When connect is called, the status is set to resolving, but we want to make sure - // the status is changed before this async method exits. Otherwise, the consumer - // might think the status has closed before it started to connect. - _status = SocketStatus::Waiting; - - auto saddress = std::string(address); - std::promise barrier; - _connectFuture = barrier.get_future(); - auto thread = std::thread( - [this, saddress, port](std::promise barrier2) -> void { - try - { - Connect(saddress.c_str(), port); - } - catch (const std::exception& ex) - { - _error = std::string(ex.what()); - } - barrier2.set_value(); - }, - std::move(barrier)); - thread.detach(); + InitialiseWSA(); + return std::make_unique(); } - void Finish() override + std::unique_ptr CreateUdpSocket() { - if (_status == SocketStatus::Connected) - { - shutdown(_socket, SHUT_WR); - } + InitialiseWSA(); + return std::make_unique(); } - void Disconnect() override - { - if (_status == SocketStatus::Connected) - { - shutdown(_socket, SHUT_RDWR); - } - _status = SocketStatus::Closed; - } - - size_t SendData(const void* buffer, size_t size) override - { - if (_status != SocketStatus::Connected) - { - throw std::runtime_error("Socket not connected."); - } - - size_t totalSent = 0; - do - { - const char* bufferStart = static_cast(buffer) + totalSent; - size_t remainingSize = size - totalSent; - int32_t sentBytes = send(_socket, bufferStart, static_cast(remainingSize), FLAG_NO_PIPE); - if (sentBytes == SOCKET_ERROR) - { - return totalSent; - } - totalSent += sentBytes; - } while (totalSent < size); - return totalSent; - } - - NetworkReadPacket ReceiveData(void* buffer, size_t size, size_t* sizeReceived) override - { - if (_status != SocketStatus::Connected) - { - throw std::runtime_error("Socket not connected."); - } - - int32_t readBytes = recv(_socket, static_cast(buffer), static_cast(size), 0); - if (readBytes == 0) - { - *sizeReceived = 0; - return NetworkReadPacket::Disconnected; - } - - if (readBytes == SOCKET_ERROR) - { - *sizeReceived = 0; - #ifndef _WIN32 - // Removing the check for EAGAIN and instead relying on the values being the same allows turning on of - // -Wlogical-op warning. - // This is not true on Windows, see: - // * https://msdn.microsoft.com/en-us/library/windows/desktop/ms737828(v=vs.85).aspx - // * https://msdn.microsoft.com/en-us/library/windows/desktop/ms741580(v=vs.85).aspx - // * https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668(v=vs.85).aspx - static_assert( - EWOULDBLOCK == EAGAIN, - "Portability note: your system has different values for EWOULDBLOCK " - "and EAGAIN, please extend the condition below"); - #endif // _WIN32 - if (LAST_SOCKET_ERROR() != EWOULDBLOCK) - { - return NetworkReadPacket::Disconnected; - } - - return NetworkReadPacket::NoData; - } - - *sizeReceived = readBytes; - return NetworkReadPacket::Success; - } - - void Close() override - { - if (_connectFuture.valid()) - { - _connectFuture.wait(); - } - CloseSocket(); - } - - const char* GetHostName() const override - { - return _hostName.empty() ? nullptr : _hostName.c_str(); - } - - std::string GetIpAddress() const override - { - return _ipAddress; - } - -private: - void CloseSocket() - { - if (_socket != INVALID_SOCKET) - { - closesocket(_socket); - _socket = INVALID_SOCKET; - } - _status = SocketStatus::Closed; - } - - std::string GetIpAddressFromSocket(const sockaddr_in* addr) const - { - std::string result; - if (addr->sin_family == AF_INET) - { - char str[INET_ADDRSTRLEN]{}; - inet_ntop(AF_INET, &addr->sin_addr, str, sizeof(str)); - result = str; - } - else if (addr->sin_family == AF_INET6) - { - auto addrv6 = reinterpret_cast(&addr); - char str[INET6_ADDRSTRLEN]{}; - inet_ntop(AF_INET6, &addrv6->sin6_addr, str, sizeof(str)); - result = str; - } - return result; - } -}; - -class UdpSocket final : public IUdpSocket, protected Socket -{ -private: - SocketStatus _status = SocketStatus::Closed; - uint16_t _listeningPort = 0; - SOCKET _socket = INVALID_SOCKET; - NetworkEndpoint _endpoint; - - std::string _hostName; - std::string _error; - -public: - UdpSocket() noexcept = default; - - ~UdpSocket() override - { - CloseSocket(); - } - - SocketStatus GetStatus() const override - { - return _status; - } - - const char* GetError() const override - { - return _error.empty() ? nullptr : _error.c_str(); - } - - void Listen(uint16_t port) override - { - Listen("", port); - } - - void Listen(const std::string& address, uint16_t port) override - { - if (_status != SocketStatus::Closed) - { - throw std::runtime_error("Socket not closed."); - } - - sockaddr_storage ss{}; - socklen_t ss_len; - if (!ResolveAddressIPv4(address, port, &ss, &ss_len)) - { - throw SocketException("Unable to resolve address."); - } - - // Create the listening socket - _socket = CreateSocket(); - try - { - // Bind to address:port and listen - if (bind(_socket, reinterpret_cast(&ss), ss_len) != 0) - { - throw SocketException("Unable to bind to socket."); - } - } - catch (const std::exception&) - { - CloseSocket(); - throw; - } - - _listeningPort = port; - _status = SocketStatus::Listening; - } - - size_t SendData(const std::string& address, uint16_t port, const void* buffer, size_t size) override - { - sockaddr_storage ss{}; - socklen_t ss_len; - if (!ResolveAddressIPv4(address, port, &ss, &ss_len)) - { - throw SocketException("Unable to resolve address."); - } - NetworkEndpoint endpoint(reinterpret_cast(&ss), ss_len); - return SendData(endpoint, buffer, size); - } - - size_t SendData(const INetworkEndpoint& destination, const void* buffer, size_t size) override - { - if (_socket == INVALID_SOCKET) - { - _socket = CreateSocket(); - } - - const auto& dest = dynamic_cast(&destination); - if (dest == nullptr) - { - throw std::invalid_argument("destination is not compatible."); - } - auto ss = &dest->GetAddress(); - auto ss_len = dest->GetAddressLen(); - - if (_status != SocketStatus::Listening) - { - _endpoint = *dest; - } - - size_t totalSent = 0; - do - { - const char* bufferStart = static_cast(buffer) + totalSent; - size_t remainingSize = size - totalSent; - int32_t sentBytes = sendto( - _socket, bufferStart, static_cast(remainingSize), FLAG_NO_PIPE, static_cast(ss), - ss_len); - if (sentBytes == SOCKET_ERROR) - { - return totalSent; - } - totalSent += sentBytes; - } while (totalSent < size); - return totalSent; - } - - NetworkReadPacket ReceiveData( - void* buffer, size_t size, size_t* sizeReceived, std::unique_ptr* sender) override - { - sockaddr_in senderAddr{}; - socklen_t senderAddrLen = sizeof(sockaddr_in); - if (_status != SocketStatus::Listening) - { - senderAddrLen = _endpoint.GetAddressLen(); - std::memcpy(&senderAddr, &_endpoint.GetAddress(), senderAddrLen); - } - auto readBytes = recvfrom( - _socket, static_cast(buffer), static_cast(size), 0, reinterpret_cast(&senderAddr), - &senderAddrLen); - if (readBytes <= 0) - { - *sizeReceived = 0; - return NetworkReadPacket::NoData; - } - - *sizeReceived = readBytes; - if (sender != nullptr) - { - *sender = std::make_unique(reinterpret_cast(&senderAddr), senderAddrLen); - } - return NetworkReadPacket::Success; - } - - void Close() override - { - CloseSocket(); - } - - const char* GetHostName() const override - { - return _hostName.empty() ? nullptr : _hostName.c_str(); - } - -private: - SOCKET CreateSocket() const - { - auto sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - if (sock == INVALID_SOCKET) - { - throw SocketException("Unable to create socket."); - } - - // Enable send and receiving of broadcast messages - if (!SetOption(sock, SOL_SOCKET, SO_BROADCAST, true)) - { - LOG_VERBOSE("setsockopt(socket, SO_BROADCAST) failed: %d", LAST_SOCKET_ERROR()); - } - - // Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections - if (!SetOption(sock, IPPROTO_IPV6, IPV6_V6ONLY, false)) - { - LOG_VERBOSE("setsockopt(socket, IPV6_V6ONLY) failed: %d", LAST_SOCKET_ERROR()); - } - - if (!SetOption(sock, SOL_SOCKET, SO_REUSEADDR, true)) - { - LOG_VERBOSE("setsockopt(socket, SO_REUSEADDR) failed: %d", LAST_SOCKET_ERROR()); - } - - if (!SetNonBlocking(sock, true)) - { - throw SocketException("Failed to set non-blocking mode."); - } - - return sock; - } - - void CloseSocket() - { - if (_socket != INVALID_SOCKET) - { - closesocket(_socket); - _socket = INVALID_SOCKET; - } - _status = SocketStatus::Closed; - } -}; - -std::unique_ptr CreateTcpSocket() -{ - InitialiseWSA(); - return std::make_unique(); -} - -std::unique_ptr CreateUdpSocket() -{ - InitialiseWSA(); - return std::make_unique(); -} - #ifdef _WIN32 -static std::vector GetNetworkInterfaces() -{ - InitialiseWSA(); - - int sock = socket(AF_INET, SOCK_DGRAM, 0); - if (sock == -1) + static std::vector GetNetworkInterfaces() { - return {}; - } + InitialiseWSA(); - // Get all the network interfaces, requires a trial and error approach - // until we find the capacity required to store all of them. - DWORD len = 0; - size_t capacity = 16; - std::vector interfaces; - for (;;) - { - interfaces.resize(capacity); - if (WSAIoctl( - sock, SIO_GET_INTERFACE_LIST, nullptr, 0, interfaces.data(), - static_cast(capacity * sizeof(INTERFACE_INFO)), &len, nullptr, nullptr) - == 0) + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock == -1) { - break; - } - if (WSAGetLastError() != WSAEFAULT) - { - closesocket(sock); return {}; } - capacity *= 2; - } - interfaces.resize(len / sizeof(INTERFACE_INFO)); - interfaces.shrink_to_fit(); - return interfaces; -} - #endif -std::vector> GetBroadcastAddresses() -{ - std::vector> baddresses; - #ifdef _WIN32 - auto interfaces = GetNetworkInterfaces(); - for (const auto& ifo : interfaces) - { - if (ifo.iiFlags & IFF_LOOPBACK) - continue; - if (!(ifo.iiFlags & IFF_BROADCAST)) - continue; - - // iiBroadcast is unusable, because it always seems to be set to 255.255.255.255. - sockaddr_storage address{}; - memcpy(&address, &ifo.iiAddress.Address, sizeof(sockaddr)); - (reinterpret_cast(&address))->sin_addr.s_addr = ifo.iiAddress.AddressIn.sin_addr.s_addr - | ~ifo.iiNetmask.AddressIn.sin_addr.s_addr; - baddresses.push_back( - std::make_unique( - reinterpret_cast(&address), static_cast(sizeof(sockaddr)))); - } - #else - int sock = socket(AF_INET, SOCK_DGRAM, 0); - if (sock == -1) - { - return baddresses; - } - - char buf[4 * 1024]{}; - ifconf ifconfx{}; - ifconfx.ifc_len = sizeof(buf); - ifconfx.ifc_buf = buf; - if (ioctl(sock, SIOCGIFCONF, &ifconfx) == -1) - { - close(sock); - return baddresses; - } - - const char* buf_end = buf + ifconfx.ifc_len; - for (const char* p = buf; p < buf_end;) - { - auto req = reinterpret_cast(p); - if (req->ifr_addr.sa_family == AF_INET) + // Get all the network interfaces, requires a trial and error approach + // until we find the capacity required to store all of them. + DWORD len = 0; + size_t capacity = 16; + std::vector interfaces; + for (;;) { - ifreq r; - strcpy(r.ifr_name, req->ifr_name); - if (ioctl(sock, SIOCGIFFLAGS, &r) != -1 && (r.ifr_flags & IFF_BROADCAST) && ioctl(sock, SIOCGIFBRDADDR, &r) != -1) + interfaces.resize(capacity); + if (WSAIoctl( + sock, SIO_GET_INTERFACE_LIST, nullptr, 0, interfaces.data(), + static_cast(capacity * sizeof(INTERFACE_INFO)), &len, nullptr, nullptr) + == 0) { - baddresses.push_back(std::make_unique(&r.ifr_broadaddr, sizeof(sockaddr))); + break; } + if (WSAGetLastError() != WSAEFAULT) + { + closesocket(sock); + return {}; + } + capacity *= 2; } - p += sizeof(ifreq); - #if defined(AF_LINK) && !defined(SUNOS) - p += req->ifr_addr.sa_len - sizeof(struct sockaddr); - #endif + interfaces.resize(len / sizeof(INTERFACE_INFO)); + interfaces.shrink_to_fit(); + return interfaces; } - close(sock); #endif - return baddresses; -} + + std::vector> GetBroadcastAddresses() + { + std::vector> baddresses; + #ifdef _WIN32 + auto interfaces = GetNetworkInterfaces(); + for (const auto& ifo : interfaces) + { + if (ifo.iiFlags & IFF_LOOPBACK) + continue; + if (!(ifo.iiFlags & IFF_BROADCAST)) + continue; + + // iiBroadcast is unusable, because it always seems to be set to 255.255.255.255. + sockaddr_storage address{}; + memcpy(&address, &ifo.iiAddress.Address, sizeof(sockaddr)); + (reinterpret_cast(&address))->sin_addr.s_addr = ifo.iiAddress.AddressIn.sin_addr.s_addr + | ~ifo.iiNetmask.AddressIn.sin_addr.s_addr; + baddresses.push_back( + std::make_unique( + reinterpret_cast(&address), static_cast(sizeof(sockaddr)))); + } + #else + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock == -1) + { + return baddresses; + } + + char buf[4 * 1024]{}; + ifconf ifconfx{}; + ifconfx.ifc_len = sizeof(buf); + ifconfx.ifc_buf = buf; + if (ioctl(sock, SIOCGIFCONF, &ifconfx) == -1) + { + close(sock); + return baddresses; + } + + const char* buf_end = buf + ifconfx.ifc_len; + for (const char* p = buf; p < buf_end;) + { + auto req = reinterpret_cast(p); + if (req->ifr_addr.sa_family == AF_INET) + { + ifreq r; + strcpy(r.ifr_name, req->ifr_name); + if (ioctl(sock, SIOCGIFFLAGS, &r) != -1 && (r.ifr_flags & IFF_BROADCAST) + && ioctl(sock, SIOCGIFBRDADDR, &r) != -1) + { + baddresses.push_back(std::make_unique(&r.ifr_broadaddr, sizeof(sockaddr))); + } + } + p += sizeof(ifreq); + #if defined(AF_LINK) && !defined(SUNOS) + p += req->ifr_addr.sa_len - sizeof(struct sockaddr); + #endif + } + close(sock); + #endif + return baddresses; + } + +} // namespace OpenRCT2::Network namespace OpenRCT2::Convert { diff --git a/src/openrct2/network/Socket.h b/src/openrct2/network/Socket.h index d6e13b44bb..1499ace7e7 100644 --- a/src/openrct2/network/Socket.h +++ b/src/openrct2/network/Socket.h @@ -13,93 +13,96 @@ #include #include -enum class SocketStatus +namespace OpenRCT2::Network { - Closed, - Waiting, - Resolving, - Connecting, - Connected, - Listening, -}; - -enum class NetworkReadPacket : int32_t -{ - Success, - NoData, - MoreData, - Disconnected -}; - -/** - * Represents an address and port. - */ -struct INetworkEndpoint -{ - virtual ~INetworkEndpoint() + enum class SocketStatus { - } + Closed, + Waiting, + Resolving, + Connecting, + Connected, + Listening, + }; - virtual std::string GetHostname() const = 0; -}; + enum class NetworkReadPacket : int32_t + { + Success, + NoData, + MoreData, + Disconnected + }; -/** - * Represents a TCP socket / connection or listener. - */ -struct ITcpSocket -{ -public: - virtual ~ITcpSocket() = default; + /** + * Represents an address and port. + */ + struct INetworkEndpoint + { + virtual ~INetworkEndpoint() + { + } - virtual SocketStatus GetStatus() const = 0; - virtual const char* GetError() const = 0; - virtual const char* GetHostName() const = 0; - virtual std::string GetIpAddress() const = 0; + virtual std::string GetHostname() const = 0; + }; - virtual void Listen(uint16_t port) = 0; - virtual void Listen(const std::string& address, uint16_t port) = 0; - [[nodiscard]] virtual std::unique_ptr Accept() = 0; + /** + * Represents a TCP socket / connection or listener. + */ + struct ITcpSocket + { + public: + virtual ~ITcpSocket() = default; - virtual void Connect(const std::string& address, uint16_t port) = 0; - virtual void ConnectAsync(const std::string& address, uint16_t port) = 0; + virtual SocketStatus GetStatus() const = 0; + virtual const char* GetError() const = 0; + virtual const char* GetHostName() const = 0; + virtual std::string GetIpAddress() const = 0; - virtual size_t SendData(const void* buffer, size_t size) = 0; - virtual NetworkReadPacket ReceiveData(void* buffer, size_t size, size_t* sizeReceived) = 0; + virtual void Listen(uint16_t port) = 0; + virtual void Listen(const std::string& address, uint16_t port) = 0; + [[nodiscard]] virtual std::unique_ptr Accept() = 0; - virtual void SetNoDelay(bool noDelay) = 0; + virtual void Connect(const std::string& address, uint16_t port) = 0; + virtual void ConnectAsync(const std::string& address, uint16_t port) = 0; - virtual void Finish() = 0; - virtual void Disconnect() = 0; - virtual void Close() = 0; -}; + virtual size_t SendData(const void* buffer, size_t size) = 0; + virtual NetworkReadPacket ReceiveData(void* buffer, size_t size, size_t* sizeReceived) = 0; -/** - * Represents a UDP socket / listener. - */ -struct IUdpSocket -{ -public: - virtual ~IUdpSocket() = default; + virtual void SetNoDelay(bool noDelay) = 0; - virtual SocketStatus GetStatus() const = 0; - virtual const char* GetError() const = 0; - virtual const char* GetHostName() const = 0; + virtual void Finish() = 0; + virtual void Disconnect() = 0; + virtual void Close() = 0; + }; - virtual void Listen(uint16_t port) = 0; - virtual void Listen(const std::string& address, uint16_t port) = 0; + /** + * Represents a UDP socket / listener. + */ + struct IUdpSocket + { + public: + virtual ~IUdpSocket() = default; - virtual size_t SendData(const std::string& address, uint16_t port, const void* buffer, size_t size) = 0; - virtual size_t SendData(const INetworkEndpoint& destination, const void* buffer, size_t size) = 0; - virtual NetworkReadPacket ReceiveData( - void* buffer, size_t size, size_t* sizeReceived, std::unique_ptr* sender) - = 0; + virtual SocketStatus GetStatus() const = 0; + virtual const char* GetError() const = 0; + virtual const char* GetHostName() const = 0; - virtual void Close() = 0; -}; + virtual void Listen(uint16_t port) = 0; + virtual void Listen(const std::string& address, uint16_t port) = 0; -[[nodiscard]] std::unique_ptr CreateTcpSocket(); -[[nodiscard]] std::unique_ptr CreateUdpSocket(); -[[nodiscard]] std::vector> GetBroadcastAddresses(); + virtual size_t SendData(const std::string& address, uint16_t port, const void* buffer, size_t size) = 0; + virtual size_t SendData(const INetworkEndpoint& destination, const void* buffer, size_t size) = 0; + virtual NetworkReadPacket ReceiveData( + void* buffer, size_t size, size_t* sizeReceived, std::unique_ptr* sender) + = 0; + + virtual void Close() = 0; + }; + + [[nodiscard]] std::unique_ptr CreateTcpSocket(); + [[nodiscard]] std::unique_ptr CreateUdpSocket(); + [[nodiscard]] std::vector> GetBroadcastAddresses(); +} // namespace OpenRCT2::Network namespace OpenRCT2::Convert {