PassLibrary.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Interface/Interface.h>
  9. #include <Atom/RHI/RHIUtils.h>
  10. #include <Atom/RPI.Public/RenderPipeline.h>
  11. #include <Atom/RPI.Public/Pass/Pass.h>
  12. #include <Atom/RPI.Public/Pass/PassFilter.h>
  13. #include <Atom/RPI.Public/Pass/PassSystemBus.h>
  14. #include <Atom/RPI.Public/Pass/PassSystemInterface.h>
  15. #include <Atom/RPI.Public/Pass/PassLibrary.h>
  16. #include <Atom/RPI.Reflect/Pass/PassAsset.h>
  17. #include <Atom/RPI.Reflect/Pass/ComputePassData.h>
  18. #include <Atom/RPI.Reflect/Asset/AssetUtils.h>
  19. namespace AZ
  20. {
  21. namespace RPI
  22. {
  23. // Initialization & Shutdown...
  24. void PassLibrary::Init()
  25. {
  26. AddCoreTemplates();
  27. }
  28. void PassLibrary::Shutdown()
  29. {
  30. m_isShuttingDown = true;
  31. m_passNameMapping.clear();
  32. m_templateEntries.clear();
  33. m_templateMappingAssets.clear();
  34. Data::AssetBus::MultiHandler::BusDisconnect();
  35. }
  36. // Getters...
  37. PassLibrary::TemplateEntry* PassLibrary::GetEntry(const Name& templateName)
  38. {
  39. auto itr = m_templateEntries.find(templateName);
  40. if (itr != m_templateEntries.end())
  41. {
  42. return &(itr->second);
  43. }
  44. return nullptr;
  45. }
  46. const PassLibrary::TemplateEntry* PassLibrary::GetEntry(const Name& templateName) const
  47. {
  48. auto itr = m_templateEntries.find(templateName);
  49. if (itr != m_templateEntries.end())
  50. {
  51. return &(itr->second);
  52. }
  53. return nullptr;
  54. }
  55. const AZStd::shared_ptr<const PassTemplate> PassLibrary::GetPassTemplate(const Name& templateName) const
  56. {
  57. const TemplateEntry* entry = GetEntry(templateName);
  58. return entry ? entry->m_template : nullptr;
  59. }
  60. const AZStd::vector<Pass*>& PassLibrary::GetPassesForTemplate(const Name& templateName) const
  61. {
  62. static AZStd::vector<Pass*> emptyPassList;
  63. const TemplateEntry* entry = GetEntry(templateName);
  64. return entry ? entry->m_passes : emptyPassList;
  65. }
  66. bool PassLibrary::HasTemplate(const Name& templateName) const
  67. {
  68. return m_templateEntries.find(templateName) != m_templateEntries.end();
  69. }
  70. bool PassLibrary::HasPassesForTemplate(const Name& templateName) const
  71. {
  72. return (GetPassesForTemplate(templateName).size() > 0);
  73. }
  74. void PassLibrary::ForEachPass(const PassFilter& passFilter, AZStd::function<PassFilterExecutionFlow(Pass*)> passFunction)
  75. {
  76. uint32_t filterOptions = passFilter.GetEnabledFilterOptions();
  77. // A lambda function which visits each pass in a pass list, if the pass matches the pass filter, then call the pass function
  78. auto visitList = [passFilter, passFunction](const AZStd::vector<Pass*>& passList, uint32_t options) -> PassFilterExecutionFlow
  79. {
  80. if (passList.size() == 0)
  81. {
  82. return PassFilterExecutionFlow::ContinueVisitingPasses;
  83. }
  84. // if there is not other filter options enabled, skip the filter and call pass functions directly
  85. if (options == PassFilter::FilterOptions::Empty)
  86. {
  87. for (Pass* pass : passList)
  88. {
  89. // If user want to skip processing, return directly.
  90. if (passFunction(pass) == PassFilterExecutionFlow::StopVisitingPasses)
  91. {
  92. return PassFilterExecutionFlow::StopVisitingPasses;
  93. }
  94. }
  95. return PassFilterExecutionFlow::ContinueVisitingPasses;
  96. }
  97. // Check with the pass filter and call pass functions
  98. for (Pass* pass : passList)
  99. {
  100. if (passFilter.Matches(pass, options))
  101. {
  102. if (passFunction(pass) == PassFilterExecutionFlow::StopVisitingPasses)
  103. {
  104. return PassFilterExecutionFlow::StopVisitingPasses;
  105. }
  106. }
  107. }
  108. return PassFilterExecutionFlow::ContinueVisitingPasses;
  109. };
  110. // Check pass template name first
  111. if (filterOptions & PassFilter::FilterOptions::PassTemplateName)
  112. {
  113. auto entry = GetEntry(passFilter.GetPassTemplateName());
  114. if (!entry)
  115. {
  116. return;
  117. }
  118. filterOptions &= ~(PassFilter::FilterOptions::PassTemplateName);
  119. visitList(entry->m_passes, filterOptions);
  120. return;
  121. }
  122. else if (filterOptions & PassFilter::FilterOptions::PassName)
  123. {
  124. const auto constItr = m_passNameMapping.find(passFilter.GetPassName());
  125. if (constItr == m_passNameMapping.end())
  126. {
  127. return;
  128. }
  129. filterOptions &= ~(PassFilter::FilterOptions::PassName);
  130. visitList(constItr->second, filterOptions);
  131. return;
  132. }
  133. // check againest every passes. This might be slow
  134. AZ_PROFILE_SCOPE(RPI, "PassLibrary::ForEachPass");
  135. for (auto& namePasses : m_passNameMapping)
  136. {
  137. if (visitList(namePasses.second, filterOptions) == PassFilterExecutionFlow::StopVisitingPasses)
  138. {
  139. return;
  140. }
  141. }
  142. }
  143. // Add Functions...
  144. void PassLibrary::AddPass(Pass* pass)
  145. {
  146. if (pass->m_template)
  147. {
  148. TemplateEntry* entry = GetEntry(pass->m_template->m_name);
  149. if (entry)
  150. {
  151. entry->m_passes.push_back(pass);
  152. }
  153. }
  154. m_passNameMapping[pass->m_name].push_back(pass);
  155. }
  156. void PassLibrary::AddCoreTemplates()
  157. {
  158. // Put calls to pass template creation functions here...
  159. AddCopyPassTemplate();
  160. }
  161. void PassLibrary::AddCopyPassTemplate()
  162. {
  163. AZStd::shared_ptr<PassTemplate> passTemplate = AZStd::make_shared<PassTemplate>();
  164. passTemplate->m_passClass = "CopyPass";
  165. passTemplate->m_name = "CopyPassTemplate";
  166. PassSlot inputSlot;
  167. inputSlot.m_name = "Input";
  168. inputSlot.m_slotType = PassSlotType::Input;
  169. inputSlot.m_scopeAttachmentUsage = RHI::ScopeAttachmentUsage::Copy;
  170. inputSlot.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  171. passTemplate->m_slots.emplace_back(inputSlot);
  172. PassSlot outputSlot;
  173. outputSlot.m_name = "Output";
  174. outputSlot.m_slotType = PassSlotType::Output;
  175. outputSlot.m_scopeAttachmentUsage = RHI::ScopeAttachmentUsage::Copy;
  176. outputSlot.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  177. passTemplate->m_slots.emplace_back(outputSlot);
  178. AddPassTemplate(passTemplate->m_name, std::move(passTemplate));
  179. passTemplate = AZStd::make_shared<PassTemplate>();
  180. passTemplate->m_passClass = "CopyPass";
  181. passTemplate->m_name = "MultiDeviceCopyPassTemplate";
  182. PassSlot inputOutputSlot;
  183. inputOutputSlot.m_name = "InputOutput";
  184. inputOutputSlot.m_slotType = PassSlotType::InputOutput;
  185. inputOutputSlot.m_scopeAttachmentUsage = RHI::ScopeAttachmentUsage::Copy;
  186. inputOutputSlot.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  187. passTemplate->m_slots.emplace_back(inputOutputSlot);
  188. AddPassTemplate(passTemplate->m_name, std::move(passTemplate));
  189. }
  190. bool PassLibrary::AddPassTemplate(const Name& name, const AZStd::shared_ptr<PassTemplate>& passTemplate, bool hotReloading)
  191. {
  192. // Check if template already exists (unless we're hot reloading)
  193. if (!hotReloading && GetPassTemplate(name) != nullptr)
  194. {
  195. AZ_Warning("PassLibrary", false,
  196. "Trying to add a PassTemplate that already exists in PassLibrary. Template name: %s", name.GetCStr());
  197. return false;
  198. }
  199. if (!passTemplate)
  200. {
  201. AZ_Warning("PassLibrary", false,
  202. "Trying to add a null PassTemplate. Template name: %s", name.GetCStr());
  203. return false;
  204. }
  205. if (passTemplate->m_name != name)
  206. {
  207. AZ_Warning("PassLibrary", false,
  208. "Pass template alias [%s] is different than its name [%s]", name.GetCStr(), passTemplate->m_name.GetCStr());
  209. passTemplate->m_name = name;
  210. }
  211. // Signal that the pass template is being added in case somebody wants to add attachments.
  212. PassSystemTemplateNotificationsBus::Event(
  213. name, &PassSystemTemplateNotificationsBus::Events::OnAddingPassTemplate, passTemplate);
  214. ValidateDeviceFormats(passTemplate);
  215. m_templateEntries[name].m_template = std::move(passTemplate);
  216. return true;
  217. }
  218. void PassLibrary::RemovePassTemplate(const Name& name)
  219. {
  220. auto itr = m_templateEntries.find(name);
  221. if (itr != m_templateEntries.end())
  222. {
  223. AZ_Assert(itr->second.m_passes.empty(), "Can not delete PassTemplate '%s' because there are %zu Passes referencing it",
  224. name.GetCStr(), itr->second.m_passes.size());
  225. AZ_Assert(!itr->second.m_mappingAssetId.IsValid(), "Can not delete PassTemplate '%s' because it was created from an asset",
  226. name.GetCStr());
  227. m_templateEntries.erase(itr);
  228. }
  229. }
  230. void PassLibrary::RemovePassFromLibrary(Pass* pass)
  231. {
  232. if (m_isShuttingDown)
  233. {
  234. return;
  235. }
  236. // Remove from associated template
  237. if (pass->m_template)
  238. {
  239. TemplateEntry* entry = GetEntry(pass->m_template->m_name);
  240. if (entry)
  241. {
  242. [[maybe_unused]] auto iter = AZStd::remove(entry->m_passes.begin(), entry->m_passes.end(), pass);
  243. AZ_Assert((iter + 1) == entry->m_passes.end(),
  244. "Pass [%s] is being deleted but was not registered with it's PassTemlate [%s] in the PassLibrary.",
  245. pass->m_name.GetCStr(), pass->m_template->m_name.GetCStr());
  246. // Delete the pass that is now at the end of the list
  247. entry->m_passes.pop_back();
  248. }
  249. }
  250. // Remove pass from pass name
  251. AZ_Assert(m_passNameMapping.find(pass->GetName()) != m_passNameMapping.end(),
  252. "Pass [%s] is trying to be removed from PassLibrary but was not found in library",
  253. pass->GetName().GetCStr());
  254. AZStd::vector<Pass*>& passes = m_passNameMapping[pass->GetName()];
  255. for (auto itr = passes.begin(); itr != passes.end(); itr++)
  256. {
  257. if (*itr == pass)
  258. {
  259. passes.erase(itr);
  260. return;
  261. }
  262. }
  263. }
  264. // Pass Asset Functions...
  265. void PassLibrary::OnAssetReloaded(Data::Asset<Data::AssetData> asset)
  266. {
  267. // Handle pass asset reload
  268. Data::Asset<PassAsset> passAsset = { asset.GetAs<PassAsset>() , AZ::Data::AssetLoadBehavior::PreLoad};
  269. if (passAsset && passAsset->GetPassTemplate())
  270. {
  271. LoadPassAsset(passAsset->GetPassTemplate()->m_name, passAsset, true);
  272. return;
  273. }
  274. // Handle template mapping reload
  275. // Note: it's a known issue that when mapping asset got reloaded, we only handle the new entries
  276. Data::Asset<AnyAsset> templateMappings = { asset.GetAs<AnyAsset>(), AZ::Data::AssetLoadBehavior::PreLoad };
  277. if (templateMappings)
  278. {
  279. auto itr = m_templateMappingAssets.find(asset->GetId());
  280. if (itr != m_templateMappingAssets.end())
  281. {
  282. LoadPassTemplateMappings(templateMappings);
  283. }
  284. }
  285. }
  286. bool PassLibrary::LoadPassAsset(const Name& name, const Data::Asset<PassAsset>& passAsset, bool hotReloading)
  287. {
  288. if (!passAsset.IsReady())
  289. {
  290. AZ_Error("PassAsset", false, "Failed to get pass asset. %s", passAsset.ToString<AZStd::string>().c_str());
  291. return false;
  292. }
  293. if (!passAsset->GetPassTemplate())
  294. {
  295. AZ_Error("PassAsset", false, "Pass asset does not contain a pass template. %s", passAsset.ToString<AZStd::string>().c_str());
  296. return false;
  297. }
  298. AZStd::shared_ptr<PassTemplate> passTemplate = passAsset->GetPassTemplate()->Clone();
  299. bool success = AddPassTemplate(name, std::move(passTemplate), hotReloading);
  300. if (success)
  301. {
  302. TemplateEntry& entry = m_templateEntries[name];
  303. entry.m_asset = passAsset;
  304. if (hotReloading)
  305. {
  306. for (Pass* pass : entry.m_passes)
  307. {
  308. if (pass->m_pipeline)
  309. {
  310. pass->m_pipeline->MarkPipelinePassChanges(PipelinePassChanges::PassAssetHotReloaded);
  311. }
  312. }
  313. }
  314. }
  315. return success;
  316. }
  317. bool PassLibrary::LoadPassAsset(const Name& name, const Data::AssetId& passAssetId)
  318. {
  319. Data::Asset<PassAsset> passAsset;
  320. if (passAssetId.IsValid())
  321. {
  322. passAsset = Data::AssetManager::Instance().GetAsset<RPI::PassAsset>(passAssetId, AZ::Data::AssetLoadBehavior::PreLoad);
  323. passAsset.BlockUntilLoadComplete();
  324. }
  325. bool loadSuccess = LoadPassAsset(name, passAsset);
  326. if (loadSuccess)
  327. {
  328. Data::AssetBus::MultiHandler::BusConnect(passAssetId);
  329. }
  330. return loadSuccess;
  331. }
  332. bool PassLibrary::LoadPassTemplateMappings(const AZStd::string& templateMappingPath)
  333. {
  334. Data::Asset<AnyAsset> mappingAsset = AssetUtils::LoadCriticalAsset<AnyAsset>(templateMappingPath.c_str(), AssetUtils::TraceLevel::Error);
  335. if (m_templateMappingAssets.find(mappingAsset.GetId()) != m_templateMappingAssets.end())
  336. {
  337. AZ_Warning("PassLibrary", false, "Pass template mapping [%s] was already loaded", mappingAsset.GetHint().c_str());
  338. return true;
  339. }
  340. bool success = LoadPassTemplateMappings(mappingAsset);
  341. if (success)
  342. {
  343. Data::AssetBus::MultiHandler::BusConnect(mappingAsset->GetId());
  344. }
  345. return success;
  346. }
  347. bool PassLibrary::LoadPassTemplateMappings(Data::Asset<AnyAsset> mappingAsset)
  348. {
  349. if (mappingAsset.IsReady())
  350. {
  351. const AssetAliases* mappings = GetDataFromAnyAsset<AssetAliases>(mappingAsset);
  352. if (mappings == nullptr)
  353. {
  354. AZ_Error("PassLibrary", false, "Asset [%s] doesn't have assetAliases data", mappingAsset.GetHint().c_str());
  355. return false;
  356. }
  357. const AZStd::unordered_map<AZStd::string, Data::AssetId>& assetMapping = mappings->GetAssetMapping();
  358. Data::AssetId mappingAssetId = mappingAsset.GetId();
  359. m_templateEntries.reserve(m_templateEntries.size() + assetMapping.size());
  360. for (const auto& assetInfo : assetMapping)
  361. {
  362. Name templateName = AZ::Name(assetInfo.first);
  363. if (!HasTemplate(templateName))
  364. {
  365. bool loaded = LoadPassAsset(templateName, assetInfo.second);
  366. if (loaded)
  367. {
  368. auto& entry = m_templateEntries[templateName];
  369. entry.m_mappingAssetId = mappingAssetId;
  370. }
  371. }
  372. else
  373. {
  374. // Report a warning if the template was setup in another mappping asset.
  375. // We won't report a warning if the template was loaded from same asset. This only happens when the asset got reloaded.
  376. if (m_templateEntries[templateName].m_mappingAssetId != mappingAssetId)
  377. {
  378. AZ_Warning("PassLibrary", false, "Template [%s] was aleady added to the library. Duplicated template from [%s]",
  379. templateName.GetCStr(), mappingAsset.ToString<AZStd::string>().c_str());
  380. }
  381. }
  382. }
  383. m_templateMappingAssets[mappingAsset->GetId()] = mappingAsset;
  384. return true;
  385. }
  386. return false;
  387. }
  388. void PassLibrary::ValidateDeviceFormats(const AZStd::shared_ptr<PassTemplate>& passTemplate)
  389. {
  390. // Validate image attachments
  391. for (PassImageAttachmentDesc& imageAttachment : passTemplate->m_imageAttachments)
  392. {
  393. RHI::Format format = imageAttachment.m_imageDescriptor.m_format;
  394. AZStd::string formatLocation = AZStd::string::format("PassAttachmentDesc [%s] on PassTemplate [%s]", imageAttachment.m_name.GetCStr(), passTemplate->m_name.GetCStr());
  395. imageAttachment.m_imageDescriptor.m_format = RHI::ValidateFormat(format, formatLocation.c_str(), imageAttachment.m_formatFallbacks);
  396. }
  397. // Validate slot views
  398. for (PassSlot& slot : passTemplate->m_slots)
  399. {
  400. if (slot.m_imageViewDesc)
  401. {
  402. RHI::Format format = slot.m_imageViewDesc->m_overrideFormat;
  403. AZStd::string formatLocation = AZStd::string::format("ImageViewDescriptor on Slot [%s] in PassTemplate [%s]", slot.m_name.GetCStr(), passTemplate->m_name.GetCStr());
  404. RHI::FormatCapabilities capabilities = RHI::GetCapabilities(slot.m_scopeAttachmentUsage, slot.GetAttachmentAccess(), RHI::AttachmentType::Image);
  405. slot.m_imageViewDesc->m_overrideFormat = RHI::ValidateFormat(format, formatLocation.c_str(), slot.m_formatFallbacks, capabilities);
  406. }
  407. if (slot.m_bufferViewDesc)
  408. {
  409. RHI::Format format = slot.m_bufferViewDesc->m_elementFormat;
  410. AZStd::string formatLocation = AZStd::string::format("BufferViewDescriptor on Slot [%s] in PassTemplate [%s]", slot.m_name.GetCStr(), passTemplate->m_name.GetCStr());
  411. RHI::FormatCapabilities capabilities = RHI::GetCapabilities(slot.m_scopeAttachmentUsage, slot.GetAttachmentAccess(), RHI::AttachmentType::Buffer);
  412. slot.m_bufferViewDesc->m_elementFormat = RHI::ValidateFormat(format, formatLocation.c_str(), slot.m_formatFallbacks, capabilities);
  413. }
  414. }
  415. }
  416. } // namespace RPI
  417. } // namespace AZ