Shader.cpp 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RPI.Public/Shader/Shader.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <Atom/RHI/PipelineStateCache.h>
  11. #include <Atom/RHI/RHISystemInterface.h>
  12. #include <AtomCore/Instance/InstanceDatabase.h>
  13. #include <Atom/RPI.Public/Shader/ShaderReloadDebugTracker.h>
  14. #include <Atom/RPI.Public/Shader/ShaderSystemInterface.h>
  15. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  16. #include <AzCore/Interface/Interface.h>
  17. #include <AzCore/std/time.h>
  18. #include <AzCore/Component/TickBus.h>
  19. #define PSOCacheVersion 0 // Bump this if you want to reset PSO cache for everyone
  20. namespace AZ
  21. {
  22. namespace RPI
  23. {
  24. Data::Instance<Shader> Shader::FindOrCreate(const Data::Asset<ShaderAsset>& shaderAsset, const Name& supervariantName)
  25. {
  26. auto anySupervariantName = AZStd::any(supervariantName);
  27. // retrieve the supervariant index from the shader asset
  28. SupervariantIndex supervariantIndex = shaderAsset->GetSupervariantIndex(supervariantName);
  29. if (!supervariantIndex.IsValid())
  30. {
  31. AZ_Error("Shader", false, "Supervariant with name %s, was not found in shader %s", supervariantName.GetCStr(), shaderAsset->GetName().GetCStr());
  32. return nullptr;
  33. }
  34. // Create the instance ID using the shader asset with an additional unique identifier from the Super variant index.
  35. const Data::InstanceId instanceId =
  36. Data::InstanceId::CreateFromAsset(shaderAsset, { supervariantIndex.GetIndex() });
  37. // retrieve the shader instance from the Instance database
  38. return Data::InstanceDatabase<Shader>::Instance().FindOrCreate(instanceId, shaderAsset, &anySupervariantName);
  39. }
  40. Data::Instance<Shader> Shader::FindOrCreate(const Data::Asset<ShaderAsset>& shaderAsset)
  41. {
  42. return FindOrCreate(shaderAsset, AZ::Name{ "" });
  43. }
  44. Data::Instance<Shader> Shader::CreateInternal([[maybe_unused]] ShaderAsset& shaderAsset, const AZStd::any* anySupervariantName)
  45. {
  46. AZ_Assert(anySupervariantName != nullptr, "Invalid supervariant name param");
  47. auto supervariantName = AZStd::any_cast<AZ::Name>(*anySupervariantName);
  48. auto supervariantIndex = shaderAsset.GetSupervariantIndex(supervariantName);
  49. if (!supervariantIndex.IsValid())
  50. {
  51. AZ_Error("Shader", false, "Supervariant with name %s, was not found in shader %s", supervariantName.GetCStr(), shaderAsset.GetName().GetCStr());
  52. return nullptr;
  53. }
  54. Data::Instance<Shader> shader = aznew Shader(supervariantIndex);
  55. const RHI::ResultCode resultCode = shader->Init(shaderAsset);
  56. if (resultCode != RHI::ResultCode::Success)
  57. {
  58. return nullptr;
  59. }
  60. return shader;
  61. }
  62. Shader::~Shader()
  63. {
  64. Shutdown();
  65. }
  66. static bool GetPipelineLibraryPaths(
  67. AZStd::unordered_map<int, AZStd::string>& pipelineLibraryPaths,
  68. size_t pipelineLibraryPathLength,
  69. const ShaderAsset& shaderAsset)
  70. {
  71. if (auto* fileIOBase = IO::FileIOBase::GetInstance())
  72. {
  73. const Data::AssetId& assetId = shaderAsset.GetId();
  74. Name platformName = RHI::Factory::Get().GetName();
  75. Name shaderName = shaderAsset.GetName();
  76. AZStd::string uuidString;
  77. assetId.m_guid.ToString<AZStd::string>(uuidString, false, false);
  78. AZStd::string configString;
  79. if (RHI::BuildOptions::IsDebugBuild)
  80. {
  81. configString = "Debug";
  82. }
  83. else if (RHI::BuildOptions::IsProfileBuild)
  84. {
  85. configString = "Profile";
  86. }
  87. else
  88. {
  89. configString = "Release";
  90. }
  91. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  92. for (int deviceIndex = 0; deviceIndex < deviceCount; ++deviceIndex)
  93. {
  94. RHI::PhysicalDeviceDescriptor devicePhysicalDeviceDesc{
  95. RHI::RHISystemInterface::Get()->GetDevice(deviceIndex)->GetPhysicalDevice().GetDescriptor()
  96. };
  97. char pipelineLibraryPathTemp[AZ_MAX_PATH_LEN];
  98. azsnprintf(
  99. pipelineLibraryPathTemp,
  100. AZ_MAX_PATH_LEN,
  101. "@user@/Atom/PipelineStateCache_%s_%u_%u_%s_Ver_%i/%s/%s_%s_%d",
  102. ToString(devicePhysicalDeviceDesc.m_vendorId).data(),
  103. devicePhysicalDeviceDesc.m_deviceId,
  104. devicePhysicalDeviceDesc.m_driverVersion,
  105. configString.data(),
  106. PSOCacheVersion,
  107. platformName.GetCStr(),
  108. shaderName.GetCStr(),
  109. uuidString.data(),
  110. assetId.m_subId);
  111. char resolvedPipelineLibraryPath[AZ_MAX_PATH_LEN];
  112. fileIOBase->ResolvePath(pipelineLibraryPathTemp, resolvedPipelineLibraryPath, pipelineLibraryPathLength);
  113. pipelineLibraryPaths[deviceIndex] = resolvedPipelineLibraryPath;
  114. }
  115. return true;
  116. }
  117. return false;
  118. }
  119. RHI::ResultCode Shader::Init(ShaderAsset& shaderAsset)
  120. {
  121. Data::AssetBus::MultiHandler::BusDisconnect();
  122. ShaderVariantFinderNotificationBus::Handler::BusDisconnect();
  123. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  124. RHI::DrawListTagRegistry* drawListTagRegistry = rhiSystem->GetDrawListTagRegistry();
  125. m_asset = { &shaderAsset, AZ::Data::AssetLoadBehavior::PreLoad };
  126. m_pipelineStateType = shaderAsset.GetPipelineStateType();
  127. GetPipelineLibraryPaths(m_pipelineLibraryPaths, AZ_MAX_PATH_LEN, *m_asset);
  128. {
  129. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  130. m_shaderVariants.clear();
  131. }
  132. auto rootShaderVariantAsset = shaderAsset.GetRootVariantAsset(m_supervariantIndex);
  133. m_rootVariant.Init(m_asset, rootShaderVariantAsset, m_supervariantIndex);
  134. if (m_pipelineLibraryHandle.IsNull())
  135. {
  136. // We set up a pipeline library only once for the lifetime of the Shader instance.
  137. // This should allow the Shader to be reloaded at runtime many times, and cache and reuse PipelineState objects rather than rebuild them.
  138. // It also fixes a particular TDR crash that occurred on some hardware when hot-reloading shaders and building pipeline states
  139. // in a new pipeline library every time.
  140. RHI::PipelineStateCache* pipelineStateCache = rhiSystem->GetPipelineStateCache();
  141. auto serializedData = LoadPipelineLibrary();
  142. RHI::PipelineLibraryHandle pipelineLibraryHandle =
  143. pipelineStateCache->CreateLibrary(serializedData, m_pipelineLibraryPaths);
  144. if (pipelineLibraryHandle.IsNull())
  145. {
  146. AZ_Error("Shader", false, "Failed to create pipeline library from pipeline state cache.");
  147. return RHI::ResultCode::Fail;
  148. }
  149. m_pipelineLibraryHandle = pipelineLibraryHandle;
  150. m_pipelineStateCache = pipelineStateCache;
  151. }
  152. const Name& drawListName = shaderAsset.GetDrawListName();
  153. if (!drawListName.IsEmpty())
  154. {
  155. m_drawListTag = drawListTagRegistry->AcquireTag(drawListName);
  156. if (!m_drawListTag.IsValid())
  157. {
  158. AZ_Error("Shader", false, "Failed to acquire a DrawListTag. Entries are full.");
  159. }
  160. }
  161. ShaderVariantFinderNotificationBus::Handler::BusConnect(m_asset.GetId());
  162. m_reloadedAssets.clear();
  163. const auto& supervariants = m_asset->GetCurrentShaderApiData().m_supervariants;
  164. m_expectedAssetReloadCount = 1 /*m_asset*/ + supervariants.size();
  165. Data::AssetBus::MultiHandler::BusConnect(m_asset.GetId());
  166. for (const auto& supervariant : supervariants)
  167. {
  168. Data::AssetBus::MultiHandler::BusConnect(supervariant.m_rootShaderVariantAsset.GetId());
  169. }
  170. return RHI::ResultCode::Success;
  171. }
  172. void Shader::Shutdown()
  173. {
  174. ShaderVariantFinderNotificationBus::Handler::BusDisconnect();
  175. Data::AssetBus::MultiHandler::BusDisconnect();
  176. if (m_pipelineLibraryHandle.IsValid())
  177. {
  178. if (r_enablePsoCaching)
  179. {
  180. SavePipelineLibrary();
  181. }
  182. m_pipelineStateCache->ReleaseLibrary(m_pipelineLibraryHandle);
  183. m_pipelineStateCache = nullptr;
  184. m_pipelineLibraryHandle = {};
  185. }
  186. if (m_drawListTag.IsValid())
  187. {
  188. RHI::DrawListTagRegistry* drawListTagRegistry = RHI::RHISystemInterface::Get()->GetDrawListTagRegistry();
  189. drawListTagRegistry->ReleaseTag(m_drawListTag);
  190. m_drawListTag.Reset();
  191. }
  192. }
  193. ///////////////////////////////////////////////////////////////////////
  194. // AssetBus overrides
  195. void Shader::OnAssetReloaded(Data::Asset<Data::AssetData> asset)
  196. {
  197. ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Shader::OnAssetReloaded %s.\n",
  198. this, asset.GetHint().c_str());
  199. m_reloadedAssets.emplace(asset.GetId(), asset);
  200. if (ShaderReloadDebugTracker::IsEnabled())
  201. {
  202. ShaderReloadDebugTracker::Printf(
  203. "Current ShaderAssetPtr={%p} with RootVariantAssetPtr={%p}", m_asset.Get(), m_asset->GetRootVariantAsset().Get());
  204. ShaderReloadDebugTracker::Printf("{%p} -> Shader::OnAssetReloaded so far only %zu of %zu assets have been reloaded.",
  205. this, m_reloadedAssets.size(), m_expectedAssetReloadCount);
  206. AZStd::sys_time_t now = AZStd::GetTimeUTCMilliSecond();
  207. if (asset.GetType() == AZ::AzTypeInfo<ShaderVariantAsset>::Uuid())
  208. {
  209. ShaderReloadDebugTracker::Printf(
  210. "{%p}->Shader::OnRootVariantReloaded [current time %lld] got new variant {%p}'%s'",
  211. this,
  212. now,
  213. asset.Get(),
  214. asset.GetHint().c_str());
  215. }
  216. else
  217. {
  218. const auto* newShaderAsset = asset.GetAs<ShaderAsset>();
  219. const auto shaderVariantAsset = newShaderAsset->GetRootVariantAsset();
  220. ShaderReloadDebugTracker::Printf(
  221. "{%p}->Shader::OnShaderAssetReloaded [current time %lld] got new shader {%p}'%s' with included variant {%p}'%s'",
  222. this,
  223. now,
  224. newShaderAsset,
  225. asset.GetHint().c_str(),
  226. shaderVariantAsset.Get(),
  227. shaderVariantAsset.GetHint().c_str());
  228. }
  229. }
  230. if (m_reloadedAssets.size() != m_expectedAssetReloadCount)
  231. {
  232. return;
  233. }
  234. // Time to update all references:
  235. auto itor = m_reloadedAssets.find(m_asset.GetId());
  236. if (itor == m_reloadedAssets.end())
  237. {
  238. AZ_Error("Shader", false, "Can not find the reloaded ShaderAsset with ID '%s'. Hint '%s'",
  239. m_asset.GetId().ToString<AZStd::string>().c_str(),
  240. m_asset.GetHint().c_str());
  241. return;
  242. }
  243. m_asset = itor->second;
  244. m_reloadedAssets.erase(itor);
  245. for (auto& [assetId, rootVariantAsset] : m_reloadedAssets)
  246. {
  247. AZ_Assert(rootVariantAsset.GetType() == AZ::AzTypeInfo<ShaderVariantAsset>::Uuid(),
  248. "Was expecting only ShaderVariantAsset(s)");
  249. if (!m_asset->UpdateRootShaderVariantAsset(Data::static_pointer_cast<ShaderVariantAsset>(rootVariantAsset)))
  250. {
  251. AZ_Error("Shader", false,
  252. "Failed to update Root ShaderVariantAsset {%p}'%s'",
  253. rootVariantAsset.Get(),
  254. rootVariantAsset.GetHint().c_str());
  255. }
  256. }
  257. m_reloadedAssets.clear();
  258. Init(*m_asset.Get());
  259. ShaderReloadNotificationBus::Event(asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderReinitialized, *this);
  260. }
  261. ///////////////////////////////////////////////////////////////////////
  262. ///////////////////////////////////////////////////////////////////
  263. /// ShaderVariantFinderNotificationBus overrides
  264. void Shader::OnShaderVariantAssetReady(Data::Asset<ShaderVariantAsset> shaderVariantAsset, bool isError)
  265. {
  266. ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Shader::OnShaderVariantAssetReady %s", this, shaderVariantAsset.GetHint().c_str());
  267. AZ_Assert(shaderVariantAsset, "Reloaded ShaderVariantAsset is null");
  268. const ShaderVariantStableId stableId = shaderVariantAsset->GetStableId();
  269. // check the supervariantIndex of the ShaderVariantAsset to make sure it matches the supervariantIndex of this shader instance
  270. if (shaderVariantAsset->GetSupervariantIndex() != m_supervariantIndex.GetIndex())
  271. {
  272. return;
  273. }
  274. // We make a copy of the updated variant because OnShaderVariantReinitialized must not be called inside
  275. // m_variantCacheMutex or deadlocks may occur.
  276. // Or if there is an error, we leave this object in its default state to indicate there was an error.
  277. // [GFX TODO] We really should have a dedicated message/event for this, but that will be covered by a future task where
  278. // we will merge ShaderReloadNotificationBus messages into one. For now, we just indicate the error by passing an empty ShaderVariant,
  279. // all our call sites don't use this data anyway.
  280. ShaderVariant updatedVariant;
  281. if (isError)
  282. {
  283. //Remark: We do not assert if the stableId == RootShaderVariantStableId, because we can not trust in the asset data
  284. //on error. so it is possible that on error the stbleId == RootShaderVariantStableId;
  285. if (stableId == RootShaderVariantStableId)
  286. {
  287. return;
  288. }
  289. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  290. m_shaderVariants.erase(stableId);
  291. }
  292. else
  293. {
  294. AZ_Assert(stableId != RootShaderVariantStableId,
  295. "The root variant is expected to be updated by the ShaderAsset.");
  296. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  297. auto iter = m_shaderVariants.find(stableId);
  298. if (iter != m_shaderVariants.end())
  299. {
  300. ShaderVariant& shaderVariant = iter->second;
  301. if (!shaderVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex))
  302. {
  303. AZ_Error("Shader", false, "Failed to init shaderVariant with StableId=%u", shaderVariantAsset->GetStableId());
  304. m_shaderVariants.erase(stableId);
  305. }
  306. else
  307. {
  308. updatedVariant = shaderVariant;
  309. }
  310. }
  311. else
  312. {
  313. //This is the first time the shader variant asset comes to life.
  314. updatedVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex);
  315. m_shaderVariants.emplace(stableId, updatedVariant);
  316. }
  317. }
  318. // [GFX TODO] It might make more sense to call OnShaderReinitialized here
  319. ShaderReloadNotificationBus::Event(m_asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, updatedVariant);
  320. }
  321. ///////////////////////////////////////////////////////////////////
  322. AZStd::unordered_map<int, ConstPtr<RHI::PipelineLibraryData>> Shader::LoadPipelineLibrary() const
  323. {
  324. AZStd::unordered_map<int, ConstPtr<RHI::PipelineLibraryData>> pipelineLibraries;
  325. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  326. for (int deviceIndex = 0; deviceIndex < deviceCount; ++deviceIndex)
  327. {
  328. pipelineLibraries[deviceIndex] =
  329. Utils::LoadObjectFromFile<RHI::PipelineLibraryData>(m_pipelineLibraryPaths.at(deviceIndex));
  330. }
  331. return pipelineLibraries;
  332. }
  333. void Shader::SavePipelineLibrary() const
  334. {
  335. if (!m_pipelineLibraryPaths.empty())
  336. {
  337. RHI::ConstPtr<RHI::PipelineLibrary> pipelineLibrary = m_pipelineStateCache->GetMergedLibrary(m_pipelineLibraryHandle);
  338. if (!pipelineLibrary)
  339. {
  340. return;
  341. }
  342. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  343. for (int deviceIndex = 0; deviceIndex < deviceCount; ++deviceIndex)
  344. {
  345. RHI::Device* device = RHI::RHISystemInterface::Get()->GetDevice(deviceIndex);
  346. RHI::ConstPtr<RHI::DevicePipelineLibrary> pipelineLib = pipelineLibrary->GetDevicePipelineLibrary(deviceIndex);
  347. // Check if explicit file load/save operation is needed as the RHI backend api may not support it
  348. if (device->GetFeatures().m_isPsoCacheFileOperationsNeeded)
  349. {
  350. RHI::ConstPtr<RHI::PipelineLibraryData> serializedData = pipelineLib->GetSerializedData();
  351. if (serializedData)
  352. {
  353. Utils::SaveObjectToFile<RHI::PipelineLibraryData>(
  354. m_pipelineLibraryPaths.at(deviceIndex), DataStream::ST_BINARY, serializedData.get());
  355. }
  356. }
  357. else
  358. {
  359. [[maybe_unused]] bool result = pipelineLib->SaveSerializedData(m_pipelineLibraryPaths.at(deviceIndex));
  360. AZ_Error("Shader", result, "Pipeline Library %s was not saved", &m_pipelineLibraryPaths.at(deviceIndex));
  361. }
  362. }
  363. }
  364. }
  365. ShaderOptionGroup Shader::CreateShaderOptionGroup() const
  366. {
  367. return ShaderOptionGroup(m_asset->GetShaderOptionGroupLayout());
  368. }
  369. const ShaderVariant& Shader::GetVariant(const ShaderVariantId& shaderVariantId)
  370. {
  371. Data::Asset<ShaderVariantAsset> shaderVariantAsset = m_asset->GetVariantAsset(shaderVariantId, m_supervariantIndex);
  372. if (!shaderVariantAsset || shaderVariantAsset->IsRootVariant())
  373. {
  374. return m_rootVariant;
  375. }
  376. return GetVariant(shaderVariantAsset->GetStableId());
  377. }
  378. const ShaderVariant& Shader::GetRootVariant()
  379. {
  380. return m_rootVariant;
  381. }
  382. const ShaderVariant& Shader::GetDefaultVariant()
  383. {
  384. ShaderOptionGroup defaultOptions = GetDefaultShaderOptions();
  385. return GetVariant(defaultOptions.GetShaderVariantId());
  386. }
  387. ShaderOptionGroup Shader::GetDefaultShaderOptions() const
  388. {
  389. return m_asset->GetDefaultShaderOptions();
  390. }
  391. ShaderVariantSearchResult Shader::FindVariantStableId(const ShaderVariantId& shaderVariantId) const
  392. {
  393. ShaderVariantSearchResult variantSearchResult = m_asset->FindVariantStableId(shaderVariantId);
  394. return variantSearchResult;
  395. }
  396. const ShaderVariant& Shader::GetVariant(ShaderVariantStableId shaderVariantStableId)
  397. {
  398. const ShaderVariant& variant = GetVariantInternal(shaderVariantStableId);
  399. if (ShaderReloadDebugTracker::IsEnabled())
  400. {
  401. AZStd::sys_time_t now = AZStd::GetTimeUTCMilliSecond();
  402. ShaderReloadDebugTracker::Printf("{%p}->Shader::GetVariant for shader '%s' [current time %lld] found variant '%s'",
  403. this, m_asset.GetHint().c_str(), now, variant.GetShaderVariantAsset().GetHint().c_str());
  404. }
  405. return variant;
  406. }
  407. const ShaderVariant& Shader::GetVariantInternal(ShaderVariantStableId shaderVariantStableId)
  408. {
  409. if (!shaderVariantStableId.IsValid() || shaderVariantStableId == ShaderAsset::RootShaderVariantStableId)
  410. {
  411. return m_rootVariant;
  412. }
  413. {
  414. AZStd::shared_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  415. auto findIt = m_shaderVariants.find(shaderVariantStableId);
  416. if (findIt != m_shaderVariants.end())
  417. {
  418. return findIt->second;
  419. }
  420. }
  421. // By calling GetVariant, an asynchronous asset load request is enqueued if the variant
  422. // is not fully ready.
  423. Data::Asset<ShaderVariantAsset> shaderVariantAsset = m_asset->GetVariantAsset(shaderVariantStableId, m_supervariantIndex);
  424. if (!shaderVariantAsset || shaderVariantAsset == m_asset->GetRootVariantAsset())
  425. {
  426. // Return the root variant when the requested variant is not ready.
  427. return m_rootVariant;
  428. }
  429. AZStd::unique_lock<decltype(m_variantCacheMutex)> lock(m_variantCacheMutex);
  430. // For performance reasons We are breaking this function into two locking steps.
  431. // which means We must check again if the variant is already in the cache.
  432. auto findIt = m_shaderVariants.find(shaderVariantStableId);
  433. if (findIt != m_shaderVariants.end())
  434. {
  435. return findIt->second;
  436. }
  437. ShaderVariant newVariant;
  438. newVariant.Init(m_asset, shaderVariantAsset, m_supervariantIndex);
  439. m_shaderVariants.emplace(shaderVariantStableId, newVariant);
  440. return m_shaderVariants.at(shaderVariantStableId);
  441. }
  442. RHI::PipelineStateType Shader::GetPipelineStateType() const
  443. {
  444. return m_pipelineStateType;
  445. }
  446. const ShaderInputContract& Shader::GetInputContract() const
  447. {
  448. return m_asset->GetInputContract(m_supervariantIndex);
  449. }
  450. const ShaderOutputContract& Shader::GetOutputContract() const
  451. {
  452. return m_asset->GetOutputContract(m_supervariantIndex);
  453. }
  454. const RHI::PipelineState* Shader::AcquirePipelineState(const RHI::PipelineStateDescriptor& descriptor) const
  455. {
  456. return m_pipelineStateCache->AcquirePipelineState(m_pipelineLibraryHandle, descriptor, m_asset->GetName());
  457. }
  458. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindShaderResourceGroupLayout(const Name& shaderResourceGroupName) const
  459. {
  460. return m_asset->FindShaderResourceGroupLayout(shaderResourceGroupName, m_supervariantIndex);
  461. }
  462. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindShaderResourceGroupLayout(uint32_t bindingSlot) const
  463. {
  464. return m_asset->FindShaderResourceGroupLayout(bindingSlot, m_supervariantIndex);
  465. }
  466. const RHI::Ptr<RHI::ShaderResourceGroupLayout>& Shader::FindFallbackShaderResourceGroupLayout() const
  467. {
  468. return m_asset->FindFallbackShaderResourceGroupLayout(m_supervariantIndex);
  469. }
  470. AZStd::span<const RHI::Ptr<RHI::ShaderResourceGroupLayout>> Shader::GetShaderResourceGroupLayouts() const
  471. {
  472. return m_asset->GetShaderResourceGroupLayouts(m_supervariantIndex);
  473. }
  474. Data::Instance<ShaderResourceGroup> Shader::CreateDrawSrgForShaderVariant(const ShaderOptionGroup& shaderOptions, bool compileTheSrg)
  475. {
  476. RHI::Ptr<RHI::ShaderResourceGroupLayout> drawSrgLayout = m_asset->GetDrawSrgLayout(GetSupervariantIndex());
  477. Data::Instance<ShaderResourceGroup> drawSrg;
  478. if (drawSrgLayout)
  479. {
  480. drawSrg = RPI::ShaderResourceGroup::Create(m_asset, GetSupervariantIndex(), drawSrgLayout->GetName());
  481. bool useFallbackKey = !shaderOptions.GetShaderOptionLayout()->IsFullySpecialized() ||
  482. !m_asset->UseSpecializationConstants(GetSupervariantIndex());
  483. if (useFallbackKey && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
  484. {
  485. drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
  486. }
  487. if (compileTheSrg)
  488. {
  489. drawSrg->Compile();
  490. }
  491. }
  492. return drawSrg;
  493. }
  494. Data::Instance<ShaderResourceGroup> Shader::CreateDefaultDrawSrg(bool compileTheSrg)
  495. {
  496. return CreateDrawSrgForShaderVariant(m_asset->GetDefaultShaderOptions(), compileTheSrg);
  497. }
  498. const Data::Asset<ShaderAsset>& Shader::GetAsset() const
  499. {
  500. return m_asset;
  501. }
  502. RHI::DrawListTag Shader::GetDrawListTag() const
  503. {
  504. return m_drawListTag;
  505. }
  506. } // namespace RPI
  507. } // namespace AZ