RayTracingPass.cpp 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Feature/RayTracing/RayTracingPass.h>
  9. #include <Atom/Feature/RayTracing/RayTracingPassData.h>
  10. #include <Atom/RHI/CommandList.h>
  11. #include <Atom/RHI/DeviceDispatchRaysItem.h>
  12. #include <Atom/RHI/DevicePipelineState.h>
  13. #include <Atom/RHI/DispatchRaysItem.h>
  14. #include <Atom/RHI/Factory.h>
  15. #include <Atom/RHI/FrameScheduler.h>
  16. #include <Atom/RHI/RHISystemInterface.h>
  17. #include <Atom/RHI/RHIUtils.h>
  18. #include <Atom/RPI.Public/Base.h>
  19. #include <Atom/RPI.Public/Pass/PassUtils.h>
  20. #include <Atom/RPI.Public/RPIUtils.h>
  21. #include <Atom/RPI.Public/RenderPipeline.h>
  22. #include <Atom/RPI.Public/Scene.h>
  23. #include <Atom/RPI.Public/View.h>
  24. #include <Atom/RPI.Reflect/Pass/PassTemplate.h>
  25. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  26. #include <AzCore/Asset/AssetCommon.h>
  27. #include <AzCore/Asset/AssetManagerBus.h>
  28. #include <RayTracing/RayTracingFeatureProcessor.h>
  29. using uint = uint32_t;
  30. using uint4 = uint[4];
  31. #include "../../../Feature/Common/Assets/ShaderLib/Atom/Features/IndirectRendering.azsli"
  32. namespace AZ
  33. {
  34. namespace Render
  35. {
  36. RPI::Ptr<RayTracingPass> RayTracingPass::Create(const RPI::PassDescriptor& descriptor)
  37. {
  38. RPI::Ptr<RayTracingPass> pass = aznew RayTracingPass(descriptor);
  39. return pass;
  40. }
  41. RayTracingPass::RayTracingPass(const RPI::PassDescriptor& descriptor)
  42. : RenderPass(descriptor)
  43. , m_passDescriptor(descriptor)
  44. , m_dispatchRaysItem(RHI::RHISystemInterface::Get()->GetRayTracingSupport())
  45. {
  46. m_flags.m_canBecomeASubpass = false;
  47. if (RHI::RHISystemInterface::Get()->GetRayTracingSupport() == RHI::MultiDevice::NoDevices)
  48. {
  49. // raytracing is not supported on this platform
  50. SetEnabled(false);
  51. return;
  52. }
  53. m_passData = RPI::PassUtils::GetPassData<RayTracingPassData>(m_passDescriptor);
  54. if (m_passData == nullptr)
  55. {
  56. AZ_Error("PassSystem", false, "RayTracingPass [%s]: Invalid RayTracingPassData", GetPathName().GetCStr());
  57. return;
  58. }
  59. m_indirectDispatch = m_passData->m_indirectDispatch;
  60. m_indirectDispatchBufferSlotName = m_passData->m_indirectDispatchBufferSlotName;
  61. m_fullscreenDispatch = m_passData->m_fullscreenDispatch;
  62. m_fullscreenSizeSourceSlotName = m_passData->m_fullscreenSizeSourceSlotName;
  63. AZ_Assert(
  64. !(m_indirectDispatch && m_fullscreenDispatch),
  65. "[RaytracingPass '%s']: Only one of the dispatch options (indirect, fullscreen) can be active",
  66. GetPathName().GetCStr());
  67. m_defaultShaderAttachmentStage = RHI::ScopeAttachmentStage::RayTracingShader;
  68. CreatePipelineState();
  69. }
  70. RayTracingPass::~RayTracingPass()
  71. {
  72. RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
  73. }
  74. void RayTracingPass::CreatePipelineState()
  75. {
  76. m_rayTracingShaderTable.reset();
  77. m_maxRayLengthInputIndex.Reset();
  78. struct RTShaderLib
  79. {
  80. AZ::Data::AssetId m_shaderAssetId;
  81. AZ::Data::Instance<AZ::RPI::Shader> m_shader;
  82. AZ::RHI::PipelineStateDescriptorForRayTracing m_pipelineStateDescriptor;
  83. AZ::Name m_rayGenerationShaderName;
  84. AZ::Name m_missShaderName;
  85. AZ::Name m_closestHitShaderName;
  86. AZ::Name m_closestHitProceduralShaderName;
  87. };
  88. AZStd::fixed_vector<RTShaderLib, 4> shaderLibs;
  89. auto loadRayTracingShader = [&](auto& assetReference, const AZ::Name& supervariantName = AZ::Name("")) -> RTShaderLib&
  90. {
  91. auto it = std::find_if(
  92. shaderLibs.begin(),
  93. shaderLibs.end(),
  94. [&](auto& entry)
  95. {
  96. return entry.m_shaderAssetId == assetReference.m_assetId;
  97. });
  98. if (it != shaderLibs.end())
  99. {
  100. return *it;
  101. }
  102. auto shaderAsset{ AZ::RPI::FindShaderAsset(assetReference.m_assetId, assetReference.m_filePath) };
  103. AZ_Assert(shaderAsset.IsReady(), "Failed to load shader %s", assetReference.m_filePath.c_str());
  104. auto shader{ AZ::RPI::Shader::FindOrCreate(shaderAsset, supervariantName) };
  105. auto shaderVariant{ shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
  106. AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
  107. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shader->GetDefaultShaderOptions());
  108. auto& shaderLib = shaderLibs.emplace_back();
  109. shaderLib.m_shaderAssetId = assetReference.m_assetId;
  110. shaderLib.m_shader = shader;
  111. shaderLib.m_pipelineStateDescriptor = pipelineStateDescriptor;
  112. return shaderLib;
  113. };
  114. auto& rayGenShaderLib{ loadRayTracingShader(m_passData->m_rayGenerationShaderAssetReference) };
  115. rayGenShaderLib.m_rayGenerationShaderName = m_passData->m_rayGenerationShaderName;
  116. m_rayGenerationShader = rayGenShaderLib.m_shader;
  117. auto& closestHitShaderLib{ loadRayTracingShader(m_passData->m_closestHitShaderAssetReference) };
  118. closestHitShaderLib.m_closestHitShaderName = m_passData->m_closestHitShaderName;
  119. m_closestHitShader = closestHitShaderLib.m_shader;
  120. if (!m_passData->m_closestHitProceduralShaderName.empty())
  121. {
  122. auto& closestHitProceduralShaderLib{ loadRayTracingShader(
  123. m_passData->m_closestHitProceduralShaderAssetReference, AZ::RHI::GetDefaultSupervariantNameWithNoFloat16Fallback()) };
  124. closestHitProceduralShaderLib.m_closestHitProceduralShaderName = m_passData->m_closestHitProceduralShaderName;
  125. m_closestHitProceduralShader = closestHitProceduralShaderLib.m_shader;
  126. }
  127. auto& missShaderLib{ loadRayTracingShader(m_passData->m_missShaderAssetReference) };
  128. missShaderLib.m_missShaderName = m_passData->m_missShaderName;
  129. m_missShader = missShaderLib.m_shader;
  130. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(shaderLibs.front().m_pipelineStateDescriptor);
  131. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  132. // create global srg
  133. const auto& globalSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingGlobalSrgBindingSlot);
  134. AZ_Error("PassSystem", globalSrgLayout != nullptr, "RayTracingPass [%s] Failed to find RayTracingGlobalSrg layout", GetPathName().GetCStr());
  135. m_shaderResourceGroup = RPI::ShaderResourceGroup::Create( m_rayGenerationShader->GetAsset(), m_rayGenerationShader->GetSupervariantIndex(), globalSrgLayout->GetName());
  136. AZ_Assert(m_shaderResourceGroup, "RayTracingPass [%s]: Failed to create RayTracingGlobalSrg", GetPathName().GetCStr());
  137. RPI::PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
  138. // check to see if the shader requires the View, Scene, or RayTracingMaterial Srgs
  139. const auto& viewSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::View);
  140. m_requiresViewSrg = (viewSrgLayout != nullptr);
  141. const auto& sceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Scene);
  142. m_requiresSceneSrg = (sceneSrgLayout != nullptr);
  143. const auto& rayTracingMaterialSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingMaterialSrgBindingSlot);
  144. m_requiresRayTracingMaterialSrg = (rayTracingMaterialSrgLayout != nullptr);
  145. const auto& rayTracingSceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingSceneSrgBindingSlot);
  146. m_requiresRayTracingSceneSrg = (rayTracingSceneSrgLayout != nullptr);
  147. // build the ray tracing pipeline state descriptor
  148. RHI::RayTracingPipelineStateDescriptor descriptor;
  149. descriptor.m_pipelineState = m_globalPipelineState.get();
  150. descriptor.m_configuration.m_maxPayloadSize = m_passData->m_maxPayloadSize;
  151. descriptor.m_configuration.m_maxAttributeSize = m_passData->m_maxAttributeSize;
  152. descriptor.m_configuration.m_maxRecursionDepth = m_passData->m_maxRecursionDepth;
  153. for (auto& shaderLib : shaderLibs)
  154. {
  155. RHI::RayTracingShaderLibrary& shaderLibrary = descriptor.m_shaderLibraries.emplace_back();
  156. shaderLibrary.m_descriptor = shaderLib.m_pipelineStateDescriptor;
  157. if (!shaderLib.m_rayGenerationShaderName.IsEmpty())
  158. {
  159. shaderLibrary.m_rayGenerationShaderName = Name(m_passData->m_rayGenerationShaderName);
  160. }
  161. if (!shaderLib.m_closestHitShaderName.IsEmpty())
  162. {
  163. shaderLibrary.m_closestHitShaderName = Name(m_passData->m_closestHitShaderName);
  164. }
  165. if (!shaderLib.m_closestHitProceduralShaderName.IsEmpty())
  166. {
  167. shaderLibrary.m_closestHitShaderName = Name(m_passData->m_closestHitProceduralShaderName);
  168. }
  169. if (!shaderLib.m_missShaderName.IsEmpty())
  170. {
  171. shaderLibrary.m_missShaderName = Name(m_passData->m_missShaderName);
  172. }
  173. }
  174. descriptor.AddHitGroup(Name("HitGroup"), Name(m_passData->m_closestHitShaderName));
  175. RayTracingFeatureProcessor* rayTracingFeatureProcessor =
  176. GetScene() ? GetScene()->GetFeatureProcessor<RayTracingFeatureProcessor>() : nullptr;
  177. if (rayTracingFeatureProcessor && !m_passData->m_closestHitProceduralShaderName.empty())
  178. {
  179. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  180. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  181. {
  182. auto shaderVariant{ it->m_intersectionShader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
  183. AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
  184. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  185. descriptor.AddIntersectionShaderLibrary(pipelineStateDescriptor, it->m_intersectionShaderName);
  186. descriptor.AddHitGroup(it->m_name, Name(m_passData->m_closestHitProceduralShaderName), it->m_intersectionShaderName);
  187. }
  188. }
  189. // create the ray tracing pipeline state object
  190. m_rayTracingPipelineState = aznew RHI::RayTracingPipelineState;
  191. m_rayTracingPipelineState->Init(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), descriptor);
  192. // register the ray tracing and global pipeline state object with the dispatch-item
  193. m_dispatchRaysItem.SetRayTracingPipelineState(m_rayTracingPipelineState.get());
  194. m_dispatchRaysItem.SetPipelineState(m_globalPipelineState.get());
  195. // make sure the shader table rebuilds if we're hotreloading
  196. m_rayTracingShaderTableRevision = 0;
  197. // store the max ray length
  198. m_maxRayLength = m_passData->m_maxRayLength;
  199. RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
  200. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_rayGenerationShaderAssetReference.m_assetId);
  201. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitShaderAssetReference.m_assetId);
  202. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitProceduralShaderAssetReference.m_assetId);
  203. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_missShaderAssetReference.m_assetId);
  204. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_intersectionShaderAssetReference.m_assetId);
  205. }
  206. bool RayTracingPass::IsEnabled() const
  207. {
  208. if (!RenderPass::IsEnabled())
  209. {
  210. return false;
  211. }
  212. if (m_pipeline == nullptr)
  213. {
  214. return false;
  215. }
  216. RPI::Scene* scene = m_pipeline->GetScene();
  217. if (!scene)
  218. {
  219. return false;
  220. }
  221. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  222. if (!rayTracingFeatureProcessor)
  223. {
  224. return false;
  225. }
  226. return true;
  227. }
  228. void RayTracingPass::BuildInternal()
  229. {
  230. if (m_indirectDispatch)
  231. {
  232. if (!m_indirectDispatchRaysBufferSignature)
  233. {
  234. AZ::RHI::IndirectBufferLayout bufferLayout;
  235. bufferLayout.AddIndirectCommand(AZ::RHI::IndirectCommandDescriptor(AZ::RHI::IndirectCommandType::DispatchRays));
  236. if (!bufferLayout.Finalize())
  237. {
  238. AZ_Assert(false, "Fail to finalize Indirect Layout");
  239. }
  240. m_indirectDispatchRaysBufferSignature = aznew AZ::RHI::IndirectBufferSignature();
  241. AZ::RHI::IndirectBufferSignatureDescriptor signatureDescriptor{};
  242. signatureDescriptor.m_layout = bufferLayout;
  243. [[maybe_unused]] auto result = m_indirectDispatchRaysBufferSignature->Init(
  244. AZ::RHI::RHISystemInterface::Get()->GetRayTracingSupport(), signatureDescriptor);
  245. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Fail to initialize Indirect Buffer Signature");
  246. }
  247. m_indirectDispatchRaysBufferBinding = nullptr;
  248. if (!m_indirectDispatchBufferSlotName.IsEmpty())
  249. {
  250. m_indirectDispatchRaysBufferBinding = FindAttachmentBinding(m_indirectDispatchBufferSlotName);
  251. AZ_Assert(m_indirectDispatchRaysBufferBinding,
  252. "[RaytracingPass '%s']: Indirect dispatch buffer slot %s not found.",
  253. GetPathName().GetCStr(),
  254. m_indirectDispatchBufferSlotName.GetCStr());
  255. if (m_indirectDispatchRaysBufferBinding)
  256. {
  257. AZ_Assert(
  258. m_indirectDispatchRaysBufferBinding->m_scopeAttachmentUsage == AZ::RHI::ScopeAttachmentUsage::Indirect,
  259. "[RaytracingPass '%s']: Indirect dispatch buffer slot %s needs ScopeAttachmentUsage::Indirect.",
  260. GetPathName().GetCStr(),
  261. m_indirectDispatchBufferSlotName.GetCStr())
  262. }
  263. }
  264. else
  265. {
  266. for (auto& binding : m_attachmentBindings)
  267. {
  268. if (binding.m_scopeAttachmentUsage == AZ::RHI::ScopeAttachmentUsage::Indirect)
  269. {
  270. m_indirectDispatchRaysBufferBinding = &binding;
  271. break;
  272. }
  273. }
  274. AZ_Assert(m_indirectDispatchRaysBufferBinding,
  275. "[RaytracingPass '%s']: No valid indirect dispatch buffer slot found.",
  276. GetPathName().GetCStr());
  277. }
  278. if (!m_dispatchRaysIndirectBuffer)
  279. {
  280. m_dispatchRaysIndirectBuffer =
  281. aznew AZ::RHI::DispatchRaysIndirectBuffer{ AZ::RHI::RHISystemInterface::Get()->GetRayTracingSupport() };
  282. m_dispatchRaysIndirectBuffer->Init(
  283. AZ::RPI::BufferSystemInterface::Get()->GetCommonBufferPool(AZ::RPI::CommonBufferPoolType::Indirect).get());
  284. }
  285. }
  286. else if (m_fullscreenDispatch)
  287. {
  288. m_fullscreenSizeSourceBinding = nullptr;
  289. if (!m_fullscreenSizeSourceSlotName.IsEmpty())
  290. {
  291. m_fullscreenSizeSourceBinding = FindAttachmentBinding(m_fullscreenSizeSourceSlotName);
  292. AZ_Assert(
  293. m_fullscreenSizeSourceBinding,
  294. "[RaytracingPass '%s']: Fullscreen size source slot %s not found.",
  295. GetPathName().GetCStr(),
  296. m_fullscreenSizeSourceSlotName.GetCStr());
  297. }
  298. else
  299. {
  300. if (GetOutputCount() > 0)
  301. {
  302. m_fullscreenSizeSourceBinding = &GetOutputBinding(0);
  303. }
  304. else if (!m_fullscreenSizeSourceBinding && GetInputOutputCount() > 0)
  305. {
  306. m_fullscreenSizeSourceBinding = &GetInputOutputBinding(0);
  307. }
  308. AZ_Assert(
  309. m_fullscreenSizeSourceBinding,
  310. "[RaytracingPass '%s']: No valid Output or InputOutput slot as a fullscreen size source found.",
  311. GetPathName().GetCStr());
  312. }
  313. }
  314. }
  315. void RayTracingPass::FrameBeginInternal(FramePrepareParams params)
  316. {
  317. RPI::Scene* scene = m_pipeline->GetScene();
  318. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  319. if (!rayTracingFeatureProcessor)
  320. {
  321. return;
  322. }
  323. RPI::RenderPass::FrameBeginInternal(params);
  324. }
  325. void RayTracingPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  326. {
  327. RPI::Scene* scene = m_pipeline->GetScene();
  328. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  329. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  330. RPI::RenderPass::SetupFrameGraphDependencies(frameGraph);
  331. frameGraph.SetEstimatedItemCount(1);
  332. // TLAS
  333. {
  334. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer();
  335. if (rayTracingTlasBuffer)
  336. {
  337. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  338. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  339. {
  340. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  341. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  342. }
  343. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer()->GetDescriptor().m_byteCount);
  344. RHI::BufferViewDescriptor tlasBufferViewDescriptor =
  345. RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  346. RHI::BufferScopeAttachmentDescriptor desc;
  347. desc.m_attachmentId = tlasAttachmentId;
  348. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  349. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  350. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
  351. }
  352. }
  353. }
  354. void RayTracingPass::CompileResources(const RHI::FrameGraphCompileContext& context)
  355. {
  356. RPI::Scene* scene = m_pipeline->GetScene();
  357. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  358. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  359. if (m_indirectDispatch)
  360. {
  361. if (m_indirectDispatchRaysBufferBinding)
  362. {
  363. auto& attachment{ m_indirectDispatchRaysBufferBinding->GetAttachment() };
  364. AZ_Assert(
  365. attachment,
  366. "[RayTracingPass '%s']: Indirect dispatch buffer slot %s has no attachment.",
  367. GetPathName().GetCStr(),
  368. m_indirectDispatchRaysBufferBinding->m_name.GetCStr());
  369. if (attachment)
  370. {
  371. auto* indirectDispatchBuffer{ context.GetBuffer(attachment->GetAttachmentId()) };
  372. m_indirectDispatchRaysBufferView = AZ::RHI::IndirectBufferView{ *indirectDispatchBuffer,
  373. *m_indirectDispatchRaysBufferSignature,
  374. 0,
  375. sizeof(DispatchRaysIndirectCommand),
  376. sizeof(DispatchRaysIndirectCommand) };
  377. RHI::DispatchRaysIndirect dispatchRaysArgs(
  378. 1, m_indirectDispatchRaysBufferView, 0, m_dispatchRaysIndirectBuffer.get());
  379. m_dispatchRaysItem.SetArguments(dispatchRaysArgs);
  380. }
  381. }
  382. }
  383. else if (m_fullscreenDispatch)
  384. {
  385. auto& attachment = m_fullscreenSizeSourceBinding->GetAttachment();
  386. AZ_Assert(
  387. attachment,
  388. "[RaytracingPass '%s']: Slot %s has no attachment for fullscreen size source.",
  389. GetPathName().GetCStr(),
  390. m_fullscreenSizeSourceBinding->m_name.GetCStr());
  391. AZ::RHI::DispatchRaysDirect dispatchRaysArgs;
  392. if (attachment)
  393. {
  394. AZ_Assert(
  395. attachment->GetAttachmentType() == AZ::RHI::AttachmentType::Image,
  396. "[RaytracingPass '%s']: Slot %s must be an image for fullscreen size source.",
  397. GetPathName().GetCStr(),
  398. m_fullscreenSizeSourceBinding->m_name.GetCStr());
  399. auto imageDescriptor = context.GetImageDescriptor(attachment->GetAttachmentId());
  400. dispatchRaysArgs.m_width = imageDescriptor.m_size.m_width;
  401. dispatchRaysArgs.m_height = imageDescriptor.m_size.m_height;
  402. dispatchRaysArgs.m_depth = imageDescriptor.m_size.m_depth;
  403. }
  404. m_dispatchRaysItem.SetArguments(dispatchRaysArgs);
  405. }
  406. else
  407. {
  408. AZ::RHI::DispatchRaysDirect dispatchRaysArgs{ m_passData->m_threadCountX,
  409. m_passData->m_threadCountY,
  410. m_passData->m_threadCountZ };
  411. m_dispatchRaysItem.SetArguments(dispatchRaysArgs);
  412. }
  413. uint32_t proceduralGeometryTypeRevision = rayTracingFeatureProcessor->GetProceduralGeometryTypeRevision();
  414. if (m_proceduralGeometryTypeRevision != proceduralGeometryTypeRevision)
  415. {
  416. CreatePipelineState();
  417. RPI::SceneNotificationBus::Event(
  418. GetScene()->GetId(),
  419. &RPI::SceneNotification::OnRenderPipelineChanged,
  420. GetRenderPipeline(),
  421. RPI::SceneNotification::RenderPipelineChangeType::PassChanged);
  422. m_proceduralGeometryTypeRevision = proceduralGeometryTypeRevision;
  423. }
  424. if (!m_rayTracingShaderTable || m_rayTracingShaderTableRevision != rayTracingFeatureProcessor->GetRevision())
  425. {
  426. // scene changed, need to rebuild the shader table
  427. m_rayTracingShaderTableRevision = rayTracingFeatureProcessor->GetRevision();
  428. m_rayTracingShaderTable = aznew AZ::RHI::RayTracingShaderTable();
  429. m_rayTracingShaderTable->Init(
  430. AZ::RHI::RHISystemInterface::Get()->GetRayTracingSupport(), rayTracingFeatureProcessor->GetBufferPools());
  431. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  432. if (rayTracingFeatureProcessor->HasGeometry())
  433. {
  434. // build the ray tracing shader table descriptor
  435. descriptor->m_name = Name("RayTracingShaderTable");
  436. descriptor->m_rayTracingPipelineState = m_rayTracingPipelineState;
  437. descriptor->m_rayGenerationRecord.emplace_back(Name(m_passData->m_rayGenerationShaderName));
  438. descriptor->m_missRecords.emplace_back(Name(m_passData->m_missShaderName));
  439. // add a hit group for standard meshes mesh to the shader table
  440. descriptor->m_hitGroupRecords.emplace_back(Name("HitGroup"));
  441. // add a hit group for each procedural geometry type to the shader table
  442. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  443. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  444. {
  445. descriptor->m_hitGroupRecords.emplace_back(it->m_name);
  446. // TODO(intersection): Set per-hitgroup SRG once RayTracingPipelineState supports local root signatures
  447. }
  448. }
  449. m_rayTracingShaderTable->Build(descriptor);
  450. // register the shader-table with the dispatch item
  451. m_dispatchRaysItem.SetRayTracingPipelineState(m_rayTracingPipelineState.get());
  452. m_dispatchRaysItem.SetRayTracingShaderTable(m_rayTracingShaderTable.get());
  453. }
  454. // Collect and register the Srgs (RayTracingGlobal, RayTracingScene, ViewSrg, SceneSrg and RayTracingMaterialSrg)
  455. // The more consistent way would be to call BindSrg() of the RenderPass, and then call
  456. // SetSrgsForDispatchRays() in BuildCommandListInternal, but that function doesn't exist.
  457. // [GFX TODO][ATOM-15610] Add RenderPass::SetSrgsForRayTracingDispatch
  458. if (m_shaderResourceGroup != nullptr)
  459. {
  460. m_shaderResourceGroup->SetConstant(m_maxRayLengthInputIndex, m_maxRayLength);
  461. BindPassSrg(context, m_shaderResourceGroup);
  462. m_shaderResourceGroup->Compile();
  463. m_rayTracingSRGsToBind.push_back(m_shaderResourceGroup->GetRHIShaderResourceGroup());
  464. }
  465. if (m_requiresRayTracingSceneSrg)
  466. {
  467. m_rayTracingSRGsToBind.push_back(rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup());
  468. }
  469. if (m_requiresViewSrg)
  470. {
  471. RPI::ViewPtr view = m_pipeline->GetFirstView(GetPipelineViewTag());
  472. if (view)
  473. {
  474. m_rayTracingSRGsToBind.push_back(view->GetShaderResourceGroup()->GetRHIShaderResourceGroup());
  475. }
  476. }
  477. if (m_requiresSceneSrg)
  478. {
  479. m_rayTracingSRGsToBind.push_back(scene->GetShaderResourceGroup()->GetRHIShaderResourceGroup());
  480. }
  481. if (m_requiresRayTracingMaterialSrg)
  482. {
  483. m_rayTracingSRGsToBind.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup());
  484. }
  485. }
  486. void RayTracingPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  487. {
  488. RPI::Scene* scene = m_pipeline->GetScene();
  489. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  490. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  491. AZ_Assert(
  492. RHI::CheckBit(rayTracingFeatureProcessor->GetDeviceMask(), context.GetDeviceIndex()),
  493. "RayTracingPass cannot run on a device without a RayTracingAccelerationStructurePass");
  494. if (!rayTracingFeatureProcessor || !rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer() ||
  495. !rayTracingFeatureProcessor->HasGeometry() || !m_rayTracingShaderTable)
  496. {
  497. return;
  498. }
  499. if (m_dispatchRaysShaderTableRevision != m_rayTracingShaderTableRevision)
  500. {
  501. m_dispatchRaysShaderTableRevision = m_rayTracingShaderTableRevision;
  502. if (m_dispatchRaysIndirectBuffer)
  503. {
  504. m_dispatchRaysIndirectBuffer->Build(m_rayTracingShaderTable.get());
  505. }
  506. }
  507. // TODO: change this to BindSrgsForDispatchRays() as soon as it exists
  508. // IMPORTANT: The data in shaderResourceGroups must be sorted by (entry)->GetBindingSlot() (FrequencyId value in SRG source file
  509. // from SrgSemantics.azsli) in order for them to be correctly assigned by Vulkan
  510. AZStd::sort(
  511. m_rayTracingSRGsToBind.begin(),
  512. m_rayTracingSRGsToBind.end(),
  513. [](const auto& lhs, const auto& rhs)
  514. {
  515. return lhs->GetBindingSlot() < rhs->GetBindingSlot();
  516. });
  517. m_dispatchRaysItem.SetShaderResourceGroups(m_rayTracingSRGsToBind.data(), static_cast<uint32_t>(m_rayTracingSRGsToBind.size()));
  518. // submit the DispatchRays item
  519. context.GetCommandList()->Submit(m_dispatchRaysItem.GetDeviceDispatchRaysItem(context.GetDeviceIndex()));
  520. }
  521. void RayTracingPass::FrameEndInternal()
  522. {
  523. m_rayTracingSRGsToBind.clear();
  524. }
  525. void RayTracingPass::OnShaderReinitialized([[maybe_unused]] const RPI::Shader& shader)
  526. {
  527. CreatePipelineState();
  528. }
  529. void RayTracingPass::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset<RPI::ShaderAsset>& shaderAsset)
  530. {
  531. CreatePipelineState();
  532. }
  533. void RayTracingPass::OnShaderVariantReinitialized(const RPI::ShaderVariant&)
  534. {
  535. CreatePipelineState();
  536. }
  537. } // namespace Render
  538. } // namespace AZ