ComputePass.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Asset/AssetCommon.h>
  9. #include <AzCore/Asset/AssetManagerBus.h>
  10. #include <Atom/RHI/CommandList.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/DevicePipelineState.h>
  14. #include <Atom/RPI.Reflect/Pass/ComputePassData.h>
  15. #include <Atom/RPI.Reflect/Pass/PassTemplate.h>
  16. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  17. #include <Atom/RPI.Public/Pass/ComputePass.h>
  18. #include <Atom/RPI.Public/Pass/PassUtils.h>
  19. #include <Atom/RPI.Public/RPIUtils.h>
  20. #include <Atom/RPI.Public/Shader/Shader.h>
  21. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  22. using uint = uint32_t;
  23. using uint4 = uint[4];
  24. #include "../../../Feature/Common/Assets/ShaderLib/Atom/Features/IndirectRendering.azsli"
  25. namespace AZ
  26. {
  27. namespace RPI
  28. {
  29. ComputePass::~ComputePass()
  30. {
  31. ShaderReloadNotificationBus::Handler::BusDisconnect();
  32. }
  33. Ptr<ComputePass> ComputePass::Create(const PassDescriptor& descriptor)
  34. {
  35. Ptr<ComputePass> pass = aznew ComputePass(descriptor);
  36. return pass;
  37. }
  38. ComputePass::ComputePass(const PassDescriptor& descriptor, AZ::Name supervariant)
  39. : RenderPass(descriptor)
  40. , m_dispatchItem(RHI::MultiDevice::AllDevices)
  41. , m_passDescriptor(descriptor)
  42. {
  43. m_flags.m_canBecomeASubpass = false;
  44. const ComputePassData* passData = PassUtils::GetPassData<ComputePassData>(m_passDescriptor);
  45. if (passData == nullptr)
  46. {
  47. AZ_Error(
  48. "PassSystem", false, "[ComputePass '%s']: Trying to construct without valid ComputePassData!", GetPathName().GetCStr());
  49. return;
  50. }
  51. m_indirectDispatch = passData->m_indirectDispatch;
  52. m_indirectDispatchBufferSlotName = passData->m_indirectDispatchBufferSlotName;
  53. m_fullscreenDispatch = passData->m_fullscreenDispatch;
  54. m_fullscreenSizeSourceSlotName = passData->m_fullscreenSizeSourceSlotName;
  55. AZ_Assert(
  56. !(m_indirectDispatch && m_fullscreenDispatch),
  57. "[ComputePass '%s']: Only one of the dispatch options (indirect, fullscreen) can be active.",
  58. GetPathName().GetCStr());
  59. RHI::DispatchDirect dispatchArgs;
  60. dispatchArgs.m_totalNumberOfThreadsX = passData->m_totalNumberOfThreadsX;
  61. dispatchArgs.m_totalNumberOfThreadsY = passData->m_totalNumberOfThreadsY;
  62. dispatchArgs.m_totalNumberOfThreadsZ = passData->m_totalNumberOfThreadsZ;
  63. m_dispatchItem.SetArguments(dispatchArgs);
  64. LoadShader(supervariant);
  65. m_defaultShaderAttachmentStage = RHI::ScopeAttachmentStage::ComputeShader;
  66. }
  67. void ComputePass::LoadShader(AZ::Name supervariant)
  68. {
  69. // Load ComputePassData...
  70. const ComputePassData* passData = PassUtils::GetPassData<ComputePassData>(m_passDescriptor);
  71. if (passData == nullptr)
  72. {
  73. AZ_Error("PassSystem", false, "[ComputePass '%s']: Trying to construct without valid ComputePassData!",
  74. GetPathName().GetCStr());
  75. return;
  76. }
  77. // Hardware Queue Class
  78. if (passData->m_useAsyncCompute)
  79. {
  80. m_hardwareQueueClass = RHI::HardwareQueueClass::Compute;
  81. }
  82. // Load Shader
  83. Data::Asset<ShaderAsset> shaderAsset;
  84. if (passData->m_shaderReference.m_assetId.IsValid())
  85. {
  86. shaderAsset = RPI::FindShaderAsset(passData->m_shaderReference.m_assetId, passData->m_shaderReference.m_filePath);
  87. }
  88. if (!shaderAsset.IsReady())
  89. {
  90. AZ_Error("PassSystem", false, "[ComputePass '%s']: Failed to load shader '%s'!",
  91. GetPathName().GetCStr(),
  92. passData->m_shaderReference.m_filePath.data());
  93. return;
  94. }
  95. m_shader = Shader::FindOrCreate(shaderAsset, supervariant);
  96. if (m_shader == nullptr)
  97. {
  98. AZ_Error("PassSystem", false, "[ComputePass '%s']: Failed to create shader instance from asset '%s'!",
  99. GetPathName().GetCStr(),
  100. passData->m_shaderReference.m_filePath.data());
  101. return;
  102. }
  103. // Load Pass SRG...
  104. const auto passSrgLayout = m_shader->FindShaderResourceGroupLayout(SrgBindingSlot::Pass);
  105. if (passSrgLayout)
  106. {
  107. m_shaderResourceGroup = ShaderResourceGroup::Create(shaderAsset, m_shader->GetSupervariantIndex(), passSrgLayout->GetName());
  108. AZ_Assert(m_shaderResourceGroup, "[ComputePass '%s']: Failed to create SRG from shader asset '%s'",
  109. GetPathName().GetCStr(),
  110. passData->m_shaderReference.m_filePath.data());
  111. PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
  112. }
  113. // Load Draw SRG...
  114. const bool compileDrawSrg = false; // The SRG will be compiled in CompileResources()
  115. m_drawSrg = m_shader->CreateDefaultDrawSrg(compileDrawSrg);
  116. if (m_dispatchItem.GetArguments().m_type == RHI::DispatchType::Direct)
  117. {
  118. auto arguments = m_dispatchItem.GetArguments();
  119. const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), arguments.m_direct);
  120. if (!outcome.IsSuccess())
  121. {
  122. AZ_Error(
  123. "PassSystem",
  124. false,
  125. "[ComputePass '%s']: Shader '%.*s' contains invalid numthreads arguments:\n%s",
  126. GetPathName().GetCStr(),
  127. passData->m_shaderReference.m_filePath.size(),
  128. passData->m_shaderReference.m_filePath.data(),
  129. outcome.GetError().c_str());
  130. }
  131. m_dispatchItem.SetArguments(arguments);
  132. }
  133. // Setup pipeline state...
  134. RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
  135. ShaderOptionGroup options = m_shader->GetDefaultShaderOptions();
  136. m_shader->GetDefaultVariant().ConfigurePipelineState(pipelineStateDescriptor, options);
  137. m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
  138. if (m_drawSrg && m_shader->GetDefaultVariant().UseKeyFallback())
  139. {
  140. m_drawSrg->SetShaderVariantKeyFallbackValue(options.GetShaderVariantKeyFallbackValue());
  141. }
  142. OnShaderReloadedInternal();
  143. ShaderReloadNotificationBus::Handler::BusDisconnect();
  144. ShaderReloadNotificationBus::Handler::BusConnect(passData->m_shaderReference.m_assetId);
  145. }
  146. // Scope producer functions
  147. void ComputePass::CompileResources(const RHI::FrameGraphCompileContext& context)
  148. {
  149. if (m_shaderResourceGroup != nullptr)
  150. {
  151. BindPassSrg(context, m_shaderResourceGroup);
  152. m_shaderResourceGroup->Compile();
  153. }
  154. if (m_drawSrg != nullptr)
  155. {
  156. BindSrg(m_drawSrg->GetRHIShaderResourceGroup());
  157. m_drawSrg->Compile();
  158. }
  159. if (m_indirectDispatch && m_indirectDispatchBufferBinding)
  160. {
  161. auto& attachment = m_indirectDispatchBufferBinding->GetAttachment();
  162. AZ_Assert(
  163. attachment,
  164. "[ComputePass '%s']: Indirect dispatch buffer slot %s has no attachment.",
  165. GetPathName().GetCStr(),
  166. m_indirectDispatchBufferBinding->m_name.GetCStr());
  167. if (attachment)
  168. {
  169. auto buffer = context.GetBuffer(attachment->GetAttachmentId());
  170. AZ_Assert(
  171. buffer,
  172. "[ComputePass '%s']: Attachment connected to Indirect dispatch buffer slot %s has no buffer",
  173. GetPathName().GetCStr(),
  174. m_indirectDispatchBufferBinding->m_name.GetCStr());
  175. m_indirectDispatchBufferView = {
  176. *buffer, *m_indirectDispatchBufferSignature, 0, sizeof(DispatchIndirectCommand), sizeof(DispatchIndirectCommand)
  177. };
  178. AZ::RHI::DispatchIndirect dispatchArgs(1, m_indirectDispatchBufferView, 0);
  179. m_dispatchItem.SetArguments(dispatchArgs);
  180. }
  181. }
  182. else if (m_fullscreenDispatch && m_fullscreenSizeSourceBinding)
  183. {
  184. auto& attachment = m_fullscreenSizeSourceBinding->GetAttachment();
  185. AZ_Assert(
  186. attachment,
  187. "[ComputePass '%s']: Slot %s has no attachment for fullscreen size source.",
  188. GetPathName().GetCStr(),
  189. m_fullscreenSizeSourceBinding->m_name.GetCStr());
  190. if (attachment)
  191. {
  192. AZ_Assert(
  193. attachment->GetAttachmentType() == AZ::RHI::AttachmentType::Image,
  194. "[ComputePass '%s']: Slot %s must be an image for fullscreen size source.",
  195. GetPathName().GetCStr(),
  196. m_fullscreenSizeSourceBinding->m_name.GetCStr());
  197. auto imageDescriptor = context.GetImageDescriptor(attachment->GetAttachmentId());
  198. // We are using the ArraySize or the image depth, whichever is bigger.
  199. // Note that this will fail for an array of 3d textures.
  200. auto depth = AZStd::max(imageDescriptor.m_size.m_depth, static_cast<uint32_t>(imageDescriptor.m_arraySize));
  201. SetTargetThreadCounts(imageDescriptor.m_size.m_width, imageDescriptor.m_size.m_height, depth);
  202. }
  203. }
  204. }
  205. void ComputePass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  206. {
  207. RHI::CommandList* commandList = context.GetCommandList();
  208. SetSrgsForDispatch(context);
  209. commandList->Submit(m_dispatchItem.GetDeviceDispatchItem(context.GetDeviceIndex()));
  210. }
  211. void ComputePass::SetTargetThreadCounts(uint32_t targetThreadCountX, uint32_t targetThreadCountY, uint32_t targetThreadCountZ)
  212. {
  213. auto arguments{m_dispatchItem.GetArguments()};
  214. arguments.m_direct.m_totalNumberOfThreadsX = targetThreadCountX;
  215. arguments.m_direct.m_totalNumberOfThreadsY = targetThreadCountY;
  216. arguments.m_direct.m_totalNumberOfThreadsZ = targetThreadCountZ;
  217. m_dispatchItem.SetArguments(arguments);
  218. }
  219. Data::Instance<ShaderResourceGroup> ComputePass::GetShaderResourceGroup() const
  220. {
  221. return m_shaderResourceGroup;
  222. }
  223. Data::Instance<Shader> ComputePass::GetShader() const
  224. {
  225. return m_shader;
  226. }
  227. void ComputePass::BuildInternal()
  228. {
  229. RenderPass::BuildInternal();
  230. if (m_indirectDispatch)
  231. {
  232. m_indirectDispatchBufferBinding = nullptr;
  233. if (!m_indirectDispatchBufferSlotName.IsEmpty())
  234. {
  235. m_indirectDispatchBufferBinding = FindAttachmentBinding(m_indirectDispatchBufferSlotName);
  236. AZ_Assert(m_indirectDispatchBufferBinding,
  237. "[ComputePass '%s']: Indirect dispatch buffer slot %s not found.",
  238. GetPathName().GetCStr(),
  239. m_indirectDispatchBufferSlotName.GetCStr());
  240. if (m_indirectDispatchBufferBinding)
  241. {
  242. AZ_Assert(
  243. m_indirectDispatchBufferBinding->m_scopeAttachmentUsage == AZ::RHI::ScopeAttachmentUsage::Indirect,
  244. "[ComputePass '%s']: Indirect dispatch buffer slot %s needs ScopeAttachmentUsage::Indirect.",
  245. GetPathName().GetCStr(),
  246. m_indirectDispatchBufferSlotName.GetCStr())
  247. }
  248. }
  249. else
  250. {
  251. for (auto& binding : m_attachmentBindings)
  252. {
  253. if (binding.m_scopeAttachmentUsage == AZ::RHI::ScopeAttachmentUsage::Indirect)
  254. {
  255. m_indirectDispatchBufferBinding = &binding;
  256. break;
  257. }
  258. }
  259. AZ_Assert(
  260. m_indirectDispatchBufferBinding,
  261. "[ComputePass '%s']: No valid indirect dispatch buffer slot found.",
  262. GetPathName().GetCStr());
  263. }
  264. AZ::RHI::IndirectBufferLayout indirectDispatchBufferLayout;
  265. indirectDispatchBufferLayout.AddIndirectCommand(AZ::RHI::IndirectCommandDescriptor(AZ::RHI::IndirectCommandType::Dispatch));
  266. if (!indirectDispatchBufferLayout.Finalize())
  267. {
  268. AZ_Assert(false, "[ComputePass '%s']: Failed to finalize Indirect Layout", GetPathName().GetCStr());
  269. }
  270. m_indirectDispatchBufferSignature = aznew AZ::RHI::IndirectBufferSignature;
  271. AZ::RHI::IndirectBufferSignatureDescriptor signatureDescriptor{};
  272. signatureDescriptor.m_layout = indirectDispatchBufferLayout;
  273. [[maybe_unused]] auto result =
  274. m_indirectDispatchBufferSignature->Init(AZ::RHI::MultiDevice::AllDevices, signatureDescriptor);
  275. AZ_Assert(
  276. result == AZ::RHI::ResultCode::Success,
  277. "[ComputePass '%s']: Failed to initialize Indirect Buffer Signature",
  278. GetPathName().GetCStr());
  279. }
  280. else if (m_fullscreenDispatch)
  281. {
  282. m_fullscreenSizeSourceBinding = nullptr;
  283. if (!m_fullscreenSizeSourceSlotName.IsEmpty())
  284. {
  285. m_fullscreenSizeSourceBinding = FindAttachmentBinding(m_fullscreenSizeSourceSlotName);
  286. AZ_Assert(
  287. m_fullscreenSizeSourceBinding,
  288. "[ComputePass '%s']: Fullscreen size source slot %s not found.",
  289. GetPathName().GetCStr(),
  290. m_fullscreenSizeSourceSlotName.GetCStr());
  291. }
  292. else
  293. {
  294. if (GetOutputCount() > 0)
  295. {
  296. m_fullscreenSizeSourceBinding = &GetOutputBinding(0);
  297. }
  298. else if (!m_fullscreenSizeSourceBinding && GetInputOutputCount() > 0)
  299. {
  300. m_fullscreenSizeSourceBinding = &GetInputOutputBinding(0);
  301. }
  302. AZ_Assert(
  303. m_fullscreenSizeSourceBinding,
  304. "[ComputePass '%s']: No valid Output or InputOutput slot as a fullscreen size source found.",
  305. GetPathName().GetCStr());
  306. }
  307. }
  308. }
  309. void ComputePass::OnShaderReinitialized(const Shader& shader)
  310. {
  311. AZ_UNUSED(shader);
  312. LoadShader();
  313. }
  314. void ComputePass::OnShaderAssetReinitialized(const Data::Asset<ShaderAsset>& shaderAsset)
  315. {
  316. AZ_UNUSED(shaderAsset);
  317. LoadShader();
  318. }
  319. void ComputePass::OnShaderVariantReinitialized(const ShaderVariant&)
  320. {
  321. LoadShader();
  322. }
  323. void ComputePass::SetComputeShaderReloadedCallback(ComputeShaderReloadedCallback callback)
  324. {
  325. m_shaderReloadedCallback = callback;
  326. }
  327. void ComputePass::UpdateShaderOptions(const ShaderVariantId& shaderVariantId)
  328. {
  329. const ShaderVariant& shaderVariant = m_shader->GetVariant(shaderVariantId);
  330. RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
  331. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderVariantId);
  332. m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
  333. if (m_drawSrg && shaderVariant.UseKeyFallback())
  334. {
  335. m_drawSrg->SetShaderVariantKeyFallbackValue(shaderVariantId.m_key);
  336. }
  337. }
  338. void ComputePass::OnShaderReloadedInternal()
  339. {
  340. if (m_shaderReloadedCallback)
  341. {
  342. m_shaderReloadedCallback(this);
  343. }
  344. }
  345. } // namespace RPI
  346. } // namespace AZ