From 1f4f0c015c3a5e65e5ef67ffe01d980fc49e508f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=B6eh=20Matt?= <5415177+ZehMatt@users.noreply.github.com> Date: Fri, 3 Sep 2021 13:43:06 +0300 Subject: [PATCH] Fix ownership of loaded object data --- src/openrct2/object/ObjectManager.cpp | 287 ++++++++++----------- src/openrct2/object/ObjectRepository.cpp | 6 +- src/openrct2/object/ObjectRepository.h | 4 +- src/openrct2/object/SceneryGroupObject.cpp | 2 +- 4 files changed, 142 insertions(+), 157 deletions(-) diff --git a/src/openrct2/object/ObjectManager.cpp b/src/openrct2/object/ObjectManager.cpp index b72f6c0bd8..c0bcc39faa 100644 --- a/src/openrct2/object/ObjectManager.cpp +++ b/src/openrct2/object/ObjectManager.cpp @@ -36,7 +36,8 @@ class ObjectManager final : public IObjectManager { private: IObjectRepository& _objectRepository; - std::vector> _loadedObjects; + + std::vector _loadedObjects; std::array, RIDE_TYPE_COUNT> _rideTypeToObjectMap; // Used to return a safe empty vector back from GetAllRideEntries, can be removed when std::span is available @@ -63,7 +64,7 @@ public: { return nullptr; } - return _loadedObjects[index].get(); + return _loadedObjects[index]; } Object* GetLoadedObject(ObjectType objectType, size_t index) override @@ -82,13 +83,11 @@ public: Object* GetLoadedObject(const ObjectEntryDescriptor& entry) override { - Object* loadedObject = nullptr; const ObjectRepositoryItem* ori = _objectRepository.FindObject(entry); - if (ori != nullptr) - { - loadedObject = ori->LoadedObject; - } - return loadedObject; + if (ori == nullptr) + return nullptr; + + return ori->LoadedObject.get(); } ObjectEntryIndex GetLoadedObjectEntryIndex(const Object* object) override @@ -142,7 +141,7 @@ public: const ObjectRepositoryItem* ori = _objectRepository.FindObject(&entry); if (ori != nullptr) { - Object* loadedObject = ori->LoadedObject; + Object* loadedObject = ori->LoadedObject.get(); if (loadedObject != nullptr) { UnloadObject(loadedObject); @@ -162,7 +161,7 @@ public: { for (auto& object : _loadedObjects) { - UnloadObject(object.get()); + UnloadObject(object); } UpdateSceneryGroupIndexes(); ResetTypeToRideEntryIndexMap(); @@ -333,42 +332,42 @@ private: Object* RepositoryItemToObject(const ObjectRepositoryItem* ori, std::optional slot = {}) { - Object* loadedObject = nullptr; - if (ori != nullptr) + if (ori == nullptr) + return nullptr; + + Object* loadedObject = ori->LoadedObject.get(); + if (loadedObject != nullptr) + return loadedObject; + + ObjectType objectType = ori->ObjectEntry.GetType(); + if (slot) { - loadedObject = ori->LoadedObject; - if (loadedObject == nullptr) + if (_loadedObjects.size() > static_cast(*slot) && _loadedObjects[*slot] != nullptr) { - ObjectType objectType = ori->ObjectEntry.GetType(); - if (slot) - { - if (_loadedObjects.size() > static_cast(*slot) && _loadedObjects[*slot] != nullptr) - { - // Slot already taken - return nullptr; - } - } - else - { - slot = FindSpareSlot(objectType); - } - if (slot) - { - auto object = GetOrLoadObject(ori); - if (object != nullptr) - { - if (_loadedObjects.size() <= static_cast(*slot)) - { - _loadedObjects.resize(*slot + 1); - } - loadedObject = object.get(); - _loadedObjects[*slot] = std::move(object); - UpdateSceneryGroupIndexes(); - ResetTypeToRideEntryIndexMap(); - } - } + // Slot already taken + return nullptr; } } + else + { + slot = FindSpareSlot(objectType); + } + if (slot) + { + auto object = GetOrLoadObject(ori); + if (object != nullptr) + { + if (_loadedObjects.size() <= static_cast(*slot)) + { + _loadedObjects.resize(*slot + 1); + } + loadedObject = object; + _loadedObjects[*slot] = object; + UpdateSceneryGroupIndexes(); + ResetTypeToRideEntryIndexMap(); + } + } + return loadedObject; } @@ -396,8 +395,7 @@ private: Guard::ArgumentNotNull(object, GUARD_LINE); auto result = std::numeric_limits().max(); - auto it = std::find_if( - _loadedObjects.begin(), _loadedObjects.end(), [object](auto& obj) { return obj.get() == object; }); + auto it = std::find_if(_loadedObjects.begin(), _loadedObjects.end(), [object](auto& obj) { return obj == object; }); if (it != _loadedObjects.end()) { result = std::distance(_loadedObjects.begin(), it); @@ -405,7 +403,7 @@ private: return result; } - void SetNewLoadedObjectList(std::vector>&& newLoadedObjects) + void SetNewLoadedObjectList(std::vector&& newLoadedObjects) { if (newLoadedObjects.empty()) { @@ -420,30 +418,24 @@ private: void UnloadObject(Object* object) { - if (object != nullptr) + if (object == nullptr) + return; + + object->Unload(); + + // TODO try to prevent doing a repository search + const ObjectRepositoryItem* ori = _objectRepository.FindObject(object->GetObjectEntry()); + if (ori != nullptr) { - object->Unload(); - - // TODO try to prevent doing a repository search - const ObjectRepositoryItem* ori = _objectRepository.FindObject(object->GetObjectEntry()); - if (ori != nullptr) - { - _objectRepository.UnregisterLoadedObject(ori, object); - } - - // Because it's possible to have the same loaded object for multiple - // slots, we have to make sure find and set all of them to nullptr - for (auto& obj : _loadedObjects) - { - if (obj.get() == object) - { - obj = nullptr; - } - } + _objectRepository.UnregisterLoadedObject(ori, object); } + + // Because it's possible to have the same loaded object for multiple + // slots, we have to make sure find and set all of them to nullptr + std::replace(_loadedObjects.begin(), _loadedObjects.end(), object, static_cast(nullptr)); } - void UnloadObjectsExcept(const std::vector>& newLoadedObjects) + void UnloadObjectsExcept(const std::vector& newLoadedObjects) { // Build a hash set for quick checking auto exceptSet = std::unordered_set(); @@ -451,7 +443,7 @@ private: { if (object != nullptr) { - exceptSet.insert(object.get()); + exceptSet.insert(object); } } @@ -463,9 +455,9 @@ private: if (object != nullptr) { totalObjectsLoaded++; - if (exceptSet.find(object.get()) == exceptSet.end()) + if (exceptSet.find(object) == exceptSet.end()) { - UnloadObject(object.get()); + UnloadObject(object); numObjectsUnloaded++; } } @@ -478,50 +470,51 @@ private: { for (auto& loadedObject : _loadedObjects) { - if (loadedObject != nullptr) + // The list can contain unused slots, skip them. + if (loadedObject == nullptr) + continue; + + switch (loadedObject->GetObjectType()) { - switch (loadedObject->GetObjectType()) + case ObjectType::SmallScenery: { - case ObjectType::SmallScenery: - { - auto* sceneryEntry = static_cast(loadedObject->GetLegacyData()); - sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); - break; - } - case ObjectType::LargeScenery: - { - auto* sceneryEntry = static_cast(loadedObject->GetLegacyData()); - sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); - break; - } - case ObjectType::Walls: - { - auto* wallEntry = static_cast(loadedObject->GetLegacyData()); - wallEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); - break; - } - case ObjectType::Banners: - { - auto* bannerEntry = static_cast(loadedObject->GetLegacyData()); - bannerEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); - break; - } - case ObjectType::PathBits: - { - auto* pathBitEntry = static_cast(loadedObject->GetLegacyData()); - pathBitEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); - break; - } - case ObjectType::SceneryGroup: - { - auto sgObject = dynamic_cast(loadedObject.get()); - sgObject->UpdateEntryIndexes(); - break; - } - default: - // This switch only handles scenery ObjectTypes. - break; + auto* sceneryEntry = static_cast(loadedObject->GetLegacyData()); + sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject); + break; } + case ObjectType::LargeScenery: + { + auto* sceneryEntry = static_cast(loadedObject->GetLegacyData()); + sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject); + break; + } + case ObjectType::Walls: + { + auto* wallEntry = static_cast(loadedObject->GetLegacyData()); + wallEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject); + break; + } + case ObjectType::Banners: + { + auto* bannerEntry = static_cast(loadedObject->GetLegacyData()); + bannerEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject); + break; + } + case ObjectType::PathBits: + { + auto* pathBitEntry = static_cast(loadedObject->GetLegacyData()); + pathBitEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject); + break; + } + case ObjectType::SceneryGroup: + { + auto sgObject = dynamic_cast(loadedObject); + sgObject->UpdateEntryIndexes(); + break; + } + default: + // This switch only handles scenery ObjectTypes. + break; } } @@ -583,7 +576,7 @@ private: } else { - auto loadedObject = ori->LoadedObject; + auto* loadedObject = ori->LoadedObject.get(); if (loadedObject == nullptr) { auto object = _objectRepository.LoadObject(ori); @@ -651,61 +644,51 @@ private: } } - std::vector> LoadObjects( - std::vector& requiredObjects, size_t* outNewObjectsLoaded) + std::vector LoadObjects(std::vector& requiredObjects, size_t* outNewObjectsLoaded) { - std::vector> objects; - std::vector loadedObjects; + std::vector objects; + std::vector newLoadedObjects; std::vector badObjects; objects.resize(OBJECT_ENTRY_COUNT); - loadedObjects.reserve(OBJECT_ENTRY_COUNT); + newLoadedObjects.reserve(OBJECT_ENTRY_COUNT); // Read objects std::mutex commonMutex; - ParallelFor(requiredObjects, [this, &commonMutex, requiredObjects, &objects, &badObjects, &loadedObjects](size_t i) { + ParallelFor(requiredObjects, [this, &commonMutex, requiredObjects, &objects, &badObjects, &newLoadedObjects](size_t i) { auto requiredObject = requiredObjects[i]; - std::unique_ptr object; + Object* object = nullptr; if (requiredObject != nullptr) { - auto loadedObject = requiredObject->LoadedObject; + auto loadedObject = requiredObject->LoadedObject.get(); if (loadedObject == nullptr) { // Object requires to be loaded, if the object successfully loads it will register it // as a loaded object otherwise placed into the badObjects list. - object = _objectRepository.LoadObject(requiredObject); + auto newObject = _objectRepository.LoadObject(requiredObject); std::lock_guard guard(commonMutex); - if (object == nullptr) + if (newObject == nullptr) { badObjects.push_back(requiredObject->ObjectEntry); ReportObjectLoadProblem(&requiredObject->ObjectEntry); } else { - loadedObjects.push_back(object.get()); + object = newObject.get(); + newLoadedObjects.push_back(object); // Connect the ori to the registered object - _objectRepository.RegisterLoadedObject(requiredObject, object.get()); + _objectRepository.RegisterLoadedObject(requiredObject, std::move(newObject)); } } else { - // The object is already loaded, given that the new list will be used as the next loaded object list, - // we can move the element out safely. This is required as the resulting list must contain all loaded - // objects and not just the newly loaded ones. - std::lock_guard guard(commonMutex); - auto it = std::find_if(_loadedObjects.begin(), _loadedObjects.end(), [loadedObject](const auto& obj) { - return obj.get() == loadedObject; - }); - if (it != _loadedObjects.end()) - { - object = std::move(*it); - } + object = loadedObject; } } - objects[i] = std::move(object); + objects[i] = object; }); // Load objects - for (auto obj : loadedObjects) + for (auto obj : newLoadedObjects) { obj->Load(); } @@ -713,7 +696,7 @@ private: if (!badObjects.empty()) { // Unload all the new objects we loaded - for (auto object : loadedObjects) + for (auto object : newLoadedObjects) { UnloadObject(object); } @@ -722,28 +705,30 @@ private: if (outNewObjectsLoaded != nullptr) { - *outNewObjectsLoaded = loadedObjects.size(); + *outNewObjectsLoaded = newLoadedObjects.size(); } return objects; } - std::unique_ptr GetOrLoadObject(const ObjectRepositoryItem* ori) + Object* GetOrLoadObject(const ObjectRepositoryItem* ori) { - std::unique_ptr object; - auto loadedObject = ori->LoadedObject; - if (loadedObject == nullptr) - { - // Try to load object - object = _objectRepository.LoadObject(ori); - if (object != nullptr) - { - object->Load(); + auto* loadedObject = ori->LoadedObject.get(); + if (loadedObject != nullptr) + return loadedObject; - // Connect the ori to the registered object - _objectRepository.RegisterLoadedObject(ori, object.get()); - } + // Try to load object + auto object = _objectRepository.LoadObject(ori); + if (object != nullptr) + { + loadedObject = object.get(); + + object->Load(); + + // Connect the ori to the registered object + _objectRepository.RegisterLoadedObject(ori, std::move(object)); } - return object; + + return loadedObject; } void ResetTypeToRideEntryIndexMap() diff --git a/src/openrct2/object/ObjectRepository.cpp b/src/openrct2/object/ObjectRepository.cpp index e436e36563..ea19c419bf 100644 --- a/src/openrct2/object/ObjectRepository.cpp +++ b/src/openrct2/object/ObjectRepository.cpp @@ -266,18 +266,18 @@ public: } } - void RegisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) override + void RegisterLoadedObject(const ObjectRepositoryItem* ori, std::unique_ptr&& object) override { ObjectRepositoryItem* item = &_items[ori->Id]; Guard::Assert(item->LoadedObject == nullptr, GUARD_LINE); - item->LoadedObject = object; + item->LoadedObject = std::move(object); } void UnregisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) override { ObjectRepositoryItem* item = &_items[ori->Id]; - if (item->LoadedObject == object) + if (item->LoadedObject.get() == object) { item->LoadedObject = nullptr; } diff --git a/src/openrct2/object/ObjectRepository.h b/src/openrct2/object/ObjectRepository.h index 049fb54fe0..e6de1cce14 100644 --- a/src/openrct2/object/ObjectRepository.h +++ b/src/openrct2/object/ObjectRepository.h @@ -43,7 +43,7 @@ struct ObjectRepositoryItem std::string Name; std::vector Authors; std::vector Sources; - Object* LoadedObject{}; + std::shared_ptr LoadedObject{}; struct { uint8_t RideFlags; @@ -82,7 +82,7 @@ struct IObjectRepository [[nodiscard]] virtual const ObjectRepositoryItem* FindObject(const ObjectEntryDescriptor& oed) const abstract; [[nodiscard]] virtual std::unique_ptr LoadObject(const ObjectRepositoryItem* ori) abstract; - virtual void RegisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) abstract; + virtual void RegisterLoadedObject(const ObjectRepositoryItem* ori, std::unique_ptr&& object) abstract; virtual void UnregisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) abstract; virtual void AddObject(const rct_object_entry* objectEntry, const void* data, size_t dataSize) abstract; diff --git a/src/openrct2/object/SceneryGroupObject.cpp b/src/openrct2/object/SceneryGroupObject.cpp index aa6a8bfd20..166245d8b1 100644 --- a/src/openrct2/object/SceneryGroupObject.cpp +++ b/src/openrct2/object/SceneryGroupObject.cpp @@ -81,7 +81,7 @@ void SceneryGroupObject::UpdateEntryIndexes() if (ori->LoadedObject == nullptr) continue; - auto entryIndex = objectManager.GetLoadedObjectEntryIndex(ori->LoadedObject); + auto entryIndex = objectManager.GetLoadedObjectEntryIndex(ori->LoadedObject.get()); Guard::Assert(entryIndex != OBJECT_ENTRY_INDEX_NULL, GUARD_LINE); auto sceneryType = ori->ObjectEntry.GetSceneryType();