HairSimulationCompute.azsl 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. /*
  2. * Modifications Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. //---------------------------------------------------------------------------------------
  9. // Shader code related to simulating hair strands in compute.
  10. //-------------------------------------------------------------------------------------
  11. //
  12. // Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved.
  13. //
  14. // Permission is hereby granted, free of charge, to any person obtaining a copy
  15. // of this software and associated documentation files (the "Software"), to deal
  16. // in the Software without restriction, including without limitation the rights
  17. // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  18. // copies of the Software, and to permit persons to whom the Software is
  19. // furnished to do so, subject to the following conditions:
  20. //
  21. // The above copyright notice and this permission notice shall be included in
  22. // all copies or substantial portions of the Software.
  23. //
  24. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  25. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  26. // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  27. // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  28. // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  29. // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  30. // THE SOFTWARE.
  31. //
  32. //--------------------------------------------------------------------------------------
  33. #include <HairSimulationComputeSrgs.azsli>
  34. #include <HairSimulationCommon.azsli>
  //--------------------------------------------------------------------------------------
  //
  // IntegrationAndGlobalShapeConstraints
  //
  // Advances each vertex one step via Integrate() (gravity/damping) and then pulls it toward
  // its skinned rest position (global shape constraint).
  //
  // Dispatch granularity: one thread per vertex.
  //
  //--------------------------------------------------------------------------------------
  [numthreads(THREAD_GROUP_SIZE, 1, 1)]
  void IntegrationAndGlobalShapeConstraints(
      uint GIndex : SV_GroupIndex,
      uint3 GId : SV_GroupID,
      uint3 DTid : SV_DispatchThreadID)
  {
      uint globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex;
      uint numVerticesInTheStrand, indexForSharedMem, strandType;
      CalcIndicesInVertexLevelMaster(GIndex, GId.x,
          globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex,
          numVerticesInTheStrand, indexForSharedMem, strandType);

      // Rest position: xyz scaled by CM_TO_METERS (w passes through unscaled), then skinned
      // with the strand's bone data.
      const float4 restScale = float4(CM_TO_METERS, CM_TO_METERS, CM_TO_METERS, 1.0);
      float4 skinnedRestPos = restScale * g_InitialHairPositions[globalVertexIndex];
      BoneSkinningData strandSkinning = g_BoneSkinningData[globalStrandIndex];
      float4 unusedBoneQuat; // out-parameter of the skinning helper; not needed in this pass
      skinnedRestPos.xyz = ApplyVertexBoneSkinning(skinnedRestPos.xyz, strandSkinning, unusedBoneQuat);

      // Position at the start of this step (i.e. the previous step's result), staged in
      // group shared memory.
      sharedPos[indexForSharedMem] = GetSharedPosition(globalVertexIndex);
      float4 stepStartPos = sharedPos[indexForSharedMem];
      GroupMemoryBarrierWithGroupSync();

      float damping = GetDamping(strandType);
      float4 previousPos = g_HairVertexPositionsPrev[globalVertexIndex];

      // Reset (e.g. after a teleport): the current position and both history buffers are set
      // to the skinned rest position so they all coincide.
      if (g_ResetPositions != 0.0f)
      {
          // Only this vertex's slot is written. The original TressFX code wrote vertex counts
          // that included follow hairs; since all hair data here lives in a single buffer,
          // writing past this vertex would overwrite and destroy unrelated content.
          stepStartPos = skinnedRestPos;
          previousPos = skinnedRestPos;
          g_HairVertexPositions[globalVertexIndex] = skinnedRestPos;
          g_HairVertexPositionsPrev[globalVertexIndex] = skinnedRestPos;
          g_HairVertexPositionsPrevPrev[globalVertexIndex] = skinnedRestPos;
      }

      // Non-movable vertices snap to the skinned rest position; movable ones are integrated.
      if (!IsMovable(stepStartPos))
      {
          sharedPos[indexForSharedMem] = skinnedRestPos;
      }
      else
      {
          sharedPos[indexForSharedMem].xyz = Integrate(stepStartPos.xyz, previousPos.xyz, skinnedRestPos.xyz, damping);
      }

      // Global shape constraint: blend movable vertices toward the skinned rest position,
      // restricted to the first (globalRange * numVerticesInTheStrand) vertices of the strand.
      const float globalStiffness = GetGlobalStiffness(strandType);
      const float globalRange = GetGlobalRange(strandType);
      if (globalStiffness > 0 && globalRange != 0.0f)
      {
          const float4 simulatedPos = sharedPos[indexForSharedMem];
          if (IsMovable(simulatedPos) && (float)localVertexIndex < globalRange * (float)numVerticesInTheStrand)
          {
              sharedPos[indexForSharedMem].xyz += globalStiffness * (skinnedRestPos.xyz - simulatedPos.xyz);
          }
      }

      // Write the step's results to the global position buffers.
      UpdateFinalVertexPositions(stepStartPos, sharedPos[indexForSharedMem], globalVertexIndex);
  }
  //--------------------------------------------------------------------------------------
  //
  // CalculateStrandLevelData
  //
  // Computes per-strand data consumed by later passes:
  //  - vspQuat / vspTranslation: the rigid transform (rotation + translation) mapping the
  //    strand's first two vertices from their previous to their current positions, plus the
  //    velocity shock propagation (VSP) coefficient in vspTranslation.w. These are read by
  //    VelocityShockPropagation.
  //  - skinningQuat: the bone quaternion produced when skinning the strand's root vertex.
  //    This is read by LocalShapeConstraints.
  //
  // One thread handles one strand and reads only its first two vertices.
  //
  //--------------------------------------------------------------------------------------
  [numthreads(THREAD_GROUP_SIZE, 1, 1)]
  void CalculateStrandLevelData(
      uint GIndex : SV_GroupIndex,
      uint3 GId : SV_GroupID,
      uint3 DTid : SV_DispatchThreadID)
  {
      uint globalStrandIndex, numVerticesInTheStrand, globalRootVertexIndex, strandType;
      CalcIndicesInStrandLevelMaster(GIndex, GId.x, globalStrandIndex, numVerticesInTheStrand, globalRootVertexIndex, strandType);

      // Element [0] is the root vertex, element [1] the vertex after it.
      float4 pos_old_old[2]; // positions two steps ago
      float4 pos_old[2];     // positions one step ago
      float4 pos_new[2];     // current positions
      pos_old_old[0] = g_HairVertexPositionsPrevPrev[globalRootVertexIndex];
      pos_old_old[1] = g_HairVertexPositionsPrevPrev[globalRootVertexIndex + 1];
      pos_old[0] = g_HairVertexPositionsPrev[globalRootVertexIndex];
      pos_old[1] = g_HairVertexPositionsPrev[globalRootVertexIndex + 1];
      pos_new[0] = g_HairVertexPositions[globalRootVertexIndex];
      pos_new[1] = g_HairVertexPositions[globalRootVertexIndex + 1];

      float3 u = normalize(pos_old[1].xyz - pos_old[0].xyz);
      float3 v = normalize(pos_new[1].xyz - pos_new[0].xyz);

      // Rotation and translation that transform pos_old to pos_new. The first two vertices are
      // immovable, so the transform is assumed to involve no scaling.
      float4 rot = QuatFromTwoUnitVectors(u, v);
      float3 trans = pos_new[0].xyz - MultQuaternionAndVector(rot, pos_old[0].xyz);

      float vspCoeff = GetVelocityShockPropogation();
      float vspAccelThreshold = GetVSPAccelThreshold();

      // Raise the VSP coefficient to 1 when vertex 1's pseudo-acceleration (second finite
      // difference of its position) exceeds the threshold. This handles over-stretching when
      // the character moves very fast. Only xyz is used, so the w component (not part of the
      // geometric position; presumably per-vertex flag data, confirm) cannot skew the result.
      float accel = length(pos_new[1].xyz - 2.0 * pos_old[1].xyz + pos_old_old[1].xyz);
      if (accel > vspAccelThreshold)
      {
          vspCoeff = 1.0f;
      }

      g_StrandLevelData[globalStrandIndex].vspQuat = rot;
      g_StrandLevelData[globalStrandIndex].vspTranslation = float4(trans, vspCoeff);

      // Skinning: skin the root's rest position (xyz scaled by CM_TO_METERS) only to obtain the
      // bone quaternion. The skinned position itself is not used further here.
      float4 initialPos = float4(CM_TO_METERS, CM_TO_METERS, CM_TO_METERS, 1.0) * g_InitialHairPositions[globalRootVertexIndex];
      BoneSkinningData skinningData = g_BoneSkinningData[globalStrandIndex];
      float4 bone_quat;
      initialPos.xyz = ApplyVertexBoneSkinning(initialPos.xyz, skinningData, bone_quat);
      g_StrandLevelData[globalStrandIndex].skinningQuat = bone_quat;
  }
  //--------------------------------------------------------------------------------------
  //
  // VelocityShockPropagation
  //
  // Propagates the velocity shock caused by the attached base mesh down each strand. The
  // strand's rigid transform (vspQuat, vspTranslation.xyz) is applied to the current and
  // previous positions of every vertex past the first two. The result is blended with the
  // untransformed position by the VSP coefficient (vspTranslation.w).
  //
  // Dispatch granularity: one thread per vertex.
  //
  //--------------------------------------------------------------------------------------
  [numthreads(THREAD_GROUP_SIZE, 1, 1)]
  void VelocityShockPropagation(
      uint GIndex : SV_GroupIndex,
      uint3 GId : SV_GroupID,
      uint3 DTid : SV_DispatchThreadID)
  {
      uint globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex;
      uint numVerticesInTheStrand, indexForSharedMem, strandType;
      CalcIndicesInVertexLevelMaster(GIndex, GId.x,
          globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex,
          numVerticesInTheStrand, indexForSharedMem, strandType);

      // Vertices 0 and 1 are attached to the skin and are left untouched.
      const bool isAttachedVertex = localVertexIndex < 2;
      if (!isAttachedVertex)
      {
          const float4 shockRotation = g_StrandLevelData[globalStrandIndex].vspQuat;
          const float4 shockTranslationAndCoeff = g_StrandLevelData[globalStrandIndex].vspTranslation;
          const float3 shockTranslation = shockTranslationAndCoeff.xyz;
          const float blend = shockTranslationAndCoeff.w;
          const float keep = 1.f - blend;

          float3 currentXyz = g_HairVertexPositions[globalVertexIndex].xyz;
          float3 previousXyz = g_HairVertexPositionsPrev[globalVertexIndex].xyz;

          // Both history entries get the same blend, so the per-step displacement is carried
          // along by the transform.
          currentXyz = keep * currentXyz + blend * (MultQuaternionAndVector(shockRotation, currentXyz) + shockTranslation);
          previousXyz = keep * previousXyz + blend * (MultQuaternionAndVector(shockRotation, previousXyz) + shockTranslation);

          g_HairVertexPositions[globalVertexIndex].xyz = currentXyz;
          g_HairVertexPositionsPrev[globalVertexIndex].xyz = previousXyz;
      }
  }
  //--------------------------------------------------------------------------------------
  //
  // LocalShapeConstraints
  //
  // Maintains the local (bending/twisting) shape of each strand. For every vertex i in
  // [1, n-2], the target for vertex i+1 is the bind-pose segment (i, i+1), rotated by the
  // strand's skinning quaternion and then by the rotation that aligns bind-pose segment
  // (i-1, i) with its current direction. Vertices i and i+1 are moved toward that target.
  //
  // One thread handles one strand and walks its vertices sequentially, so each iteration
  // sees the corrections made by the previous one.
  //
  //--------------------------------------------------------------------------------------
  [numthreads(THREAD_GROUP_SIZE, 1, 1)]
  void LocalShapeConstraints(
      uint GIndex : SV_GroupIndex,
      uint3 GId : SV_GroupID,
      uint3 DTid : SV_DispatchThreadID)
  {
      uint globalStrandIndex, numVerticesInTheStrand, globalRootVertexIndex, strandType;
      CalcIndicesInStrandLevelMaster(GIndex, GId.x, globalStrandIndex, numVerticesInTheStrand, globalRootVertexIndex, strandType);

      // A bending constraint needs three consecutive vertices; shorter strands have nothing to
      // do. This also prevents the unsigned loop bound from wrapping when the strand is empty.
      if (numVerticesInTheStrand < 3)
      {
          return;
      }

      // Going beyond this threshold makes convergence less stable. The 0.5 factor reflects that
      // the correction is applied to both vertices of a segment (-del and +del below).
      const float stabilityTH = 0.95f;
      const float stiffnessForLocalShapeMatching = 0.5f * min(GetLocalStiffness(strandType), stabilityTH);

      //--------------------------------------------
      // Local shape constraint for bending/twisting
      //--------------------------------------------
      const float4 boneQuat = g_StrandLevelData[globalStrandIndex].skinningQuat;

      // Sliding window over vertices (i-1, i, i+1), kept in registers. The positions written at
      // the end of an iteration are exactly what the next iteration would read back, and each
      // bind position is now rotated once instead of three times.
      float4 pos_minus_one = g_HairVertexPositions[globalRootVertexIndex];
      float4 pos = g_HairVertexPositions[globalRootVertexIndex + 1];
      float3 bindPos_minus_one = MultQuaternionAndVector(boneQuat, g_InitialHairPositions[globalRootVertexIndex].xyz * CM_TO_METERS);
      float3 bindPos = MultQuaternionAndVector(boneQuat, g_InitialHairPositions[globalRootVertexIndex + 1].xyz * CM_TO_METERS);

      // Vertex 1 through n-2; each iteration also corrects vertex i+1.
      for (uint localVertexIndex = 1; localVertexIndex + 1 < numVerticesInTheStrand; localVertexIndex++)
      {
          uint globalVertexIndex = globalRootVertexIndex + localVertexIndex;

          float4 pos_plus_one = g_HairVertexPositions[globalVertexIndex + 1];
          float3 bindPos_plus_one = MultQuaternionAndVector(boneQuat, g_InitialHairPositions[globalVertexIndex + 1].xyz * CM_TO_METERS);

          float3 lastVec = pos.xyz - pos_minus_one.xyz;
          float3 vecBindPose = bindPos_plus_one - bindPos;
          float3 lastVecBindPose = bindPos - bindPos_minus_one;

          // Target for vertex i+1: bind-pose segment (i, i+1) rotated by the rotation that maps
          // bind-pose segment (i-1, i) onto the current segment (i-1, i), anchored at vertex i.
          float4 rotGlobal = QuatFromTwoUnitVectors(normalize(lastVecBindPose), normalize(lastVec));
          float3 orgPos_i_plus_1_InGlobalFrame = MultQuaternionAndVector(rotGlobal, vecBindPose) + pos.xyz;
          float3 del = stiffnessForLocalShapeMatching * (orgPos_i_plus_1_InGlobalFrame - pos_plus_one.xyz);

          if (IsMovable(pos))
          {
              pos.xyz -= del.xyz;
          }
          if (IsMovable(pos_plus_one))
          {
              pos_plus_one.xyz += del.xyz;
          }

          g_HairVertexPositions[globalVertexIndex].xyz = pos.xyz;
          g_HairVertexPositions[globalVertexIndex + 1].xyz = pos_plus_one.xyz;

          // Slide the window forward by one vertex.
          pos_minus_one = pos;
          pos = pos_plus_one;
          bindPos_minus_one = bindPos;
          bindPos = bindPos_plus_one;
      }
  }
  //--------------------------------------------------------------------------------------
  //
  // LengthConstriantsWindAndCollision
  //
  // Compute shader that enforces the length (distance) constraints between neighboring
  // vertices, computes tangents and clamps the per-step position delta. Wind and capsule
  // collision handling are currently disabled.
  //
  // One thread computes one vertex.
  //
  //--------------------------------------------------------------------------------------
  [numthreads(THREAD_GROUP_SIZE, 1, 1)]
  void LengthConstriantsWindAndCollision(uint GIndex : SV_GroupIndex,
      uint3 GId : SV_GroupID,
      uint3 DTid : SV_DispatchThreadID)
  {
      uint globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex, numVerticesInTheStrand, indexForSharedMem, strandType;
      CalcIndicesInVertexLevelMaster(GIndex, GId.x, globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex, numVerticesInTheStrand, indexForSharedMem, strandType);
      // In shared memory, the next vertex of the same strand is numOfStrandsPerThreadGroup
      // entries further along.
      uint numOfStrandsPerThreadGroup = g_NumOfStrandsPerThreadGroup;

      //------------------------------
      // Copy data into shared memory
      //------------------------------
      sharedPos[indexForSharedMem] = g_HairVertexPositions[globalVertexIndex];
      sharedLength[indexForSharedMem] = g_HairRestLengthSRV[globalVertexIndex] * CM_TO_METERS;
      GroupMemoryBarrierWithGroupSync();

      /*
      //------------
      // Wind - does not work yet and requires some LTC
      //------------
      if (any(g_Wind.xyz)) // g_Wind.w is the current frame
      {
          float4 force = float4(0, 0, 0, 0);
          if ( localVertexIndex >= 2 && localVertexIndex < numVerticesInTheStrand-1 )
          {
              // combining four winds.
              float a = ((float)(globalStrandIndex % 20))/20.0f;
              float3 w = a* g_Wind.xyz + (1.0f - a) * g_Wind1.xyz + a * g_Wind2.xyz + (1.0f - a) * g_Wind3.xyz;
              // float3 w = float3(5.2, 0, 0);
              uint sharedIndex = localVertexIndex * numOfStrandsPerThreadGroup + localStrandIndex;
              float3 v = sharedPos[sharedIndex].xyz - sharedPos[sharedIndex+numOfStrandsPerThreadGroup].xyz;
              float3 force = -cross(cross(v, w), v);
              sharedPos[sharedIndex].xyz += force*g_TimeStep*g_TimeStep;
          }
      }
      GroupMemoryBarrierWithGroupSync();
      */

      //----------------------------
      // Enforce length constraints
      //----------------------------
      // Thread k first solves segment (2k, 2k+1), then segment (2k+1, 2k+2), with a barrier
      // after each half-pass. Within a half-pass no two threads touch the same vertex.
      // Integer division gives the same counts the former float division truncated to.
      uint numFirstHalfSegments = numVerticesInTheStrand / 2;
      uint numSecondHalfSegments = (numVerticesInTheStrand - 1) / 2;
      // Loop-invariant: shared-memory index of vertex 2k of this thread's strand.
      uint sharedIndex = 2 * localVertexIndex * numOfStrandsPerThreadGroup + localStrandIndex;
      int lengthConstraintIterations = GetLengthConstraintIterations();
      for (int iterationE = 0; iterationE < lengthConstraintIterations; iterationE++)
      {
          // Each vertex should only extend towards its own side of the segment's center.
          if (localVertexIndex < numFirstHalfSegments)
          {
              ApplyDistanceConstraint(sharedPos[sharedIndex], sharedPos[sharedIndex + numOfStrandsPerThreadGroup], sharedLength[sharedIndex].x);
          }
          GroupMemoryBarrierWithGroupSync();
          if (localVertexIndex < numSecondHalfSegments)
          {
              ApplyDistanceConstraint(sharedPos[sharedIndex + numOfStrandsPerThreadGroup], sharedPos[sharedIndex + numOfStrandsPerThreadGroup * 2], sharedLength[sharedIndex + numOfStrandsPerThreadGroup].x);
          }
          GroupMemoryBarrierWithGroupSync();
      }

      //------------------------------------------
      // Collision handling with capsule objects
      //------------------------------------------
      float4 oldPos = g_HairVertexPositionsPrev[globalVertexIndex];
      // Capsule collision is currently disabled, so bAnyColDetected is always false.
      bool bAnyColDetected = false;
      // bool bAnyColDetected = ResolveCapsuleCollisions(sharedPos[indexForSharedMem], oldPos);
      GroupMemoryBarrierWithGroupSync();

      //-------------------
      // Compute tangent
      //-------------------
      // The last vertex has no successor, so it reuses the segment from the previous vertex to itself.
      uint indexForTangent = (localVertexIndex == numVerticesInTheStrand - 1) ? indexForSharedMem - numOfStrandsPerThreadGroup : indexForSharedMem;
      float3 tangent = sharedPos[indexForTangent + numOfStrandsPerThreadGroup].xyz - sharedPos[indexForTangent].xyz;
      g_HairVertexTangents[globalVertexIndex].xyz = normalize(tangent);

      //---------------------------------------
      // clamp velocities, rewrite history
      //---------------------------------------
      // If the vertex moved further than |g_ClampPositionDelta| this step, rewrite the previous
      // position so the implied displacement has length exactly |g_ClampPositionDelta|.
      // The former scale factor (clamp^2 / speedSqr) gave a length of clamp^2 / |delta|. That is
      // below the limit and shrinks further as speed grows, so it was not a clamp.
      float3 positionDelta = sharedPos[indexForSharedMem].xyz - oldPos.xyz;
      float speedSqr = dot(positionDelta, positionDelta);
      float maxPositionDelta = abs(g_ClampPositionDelta);
      if (speedSqr > maxPositionDelta * maxPositionDelta)
      {
          // speedSqr is strictly positive here, so the division is safe.
          positionDelta *= maxPositionDelta / sqrt(speedSqr);
          g_HairVertexPositionsPrev[globalVertexIndex].xyz = sharedPos[indexForSharedMem].xyz - positionDelta;
      }

      //---------------------------------------
      // update global position buffers
      //---------------------------------------
      g_HairVertexPositions[globalVertexIndex] = sharedPos[indexForSharedMem];
      if (bAnyColDetected)
      {
          g_HairVertexPositionsPrev[globalVertexIndex] = sharedPos[indexForSharedMem];
      }
  }
  //--------------------------------------------------------------------------------------
  //
  // UpdateFollowHairVertices
  //
  // Last stage: updates the follow hairs so they track their guide hair.
  // - Storage: follow strand i+1 of a guide is stored (i+1) * numVerticesInTheStrand vertices
  //   after the guide vertex, with strand index globalStrandIndex + i + 1.
  // - Position: the guide position plus that strand's root offset. The offset is scaled by
  //   CM_TO_METERS and by a factor of 1 at the root that grows linearly with
  //   g_TipSeparationFactor toward the tip.
  // - Tangent: the follow vertex copies the guide tangent.
  //
  // One thread computes one guide-hair vertex.
  //
  //--------------------------------------------------------------------------------------
  [numthreads(THREAD_GROUP_SIZE, 1, 1)]
  void UpdateFollowHairVertices(
      uint GIndex : SV_GroupIndex,
      uint3 GId : SV_GroupID,
      uint3 DTid : SV_DispatchThreadID)
  {
      uint globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex, numVerticesInTheStrand, indexForSharedMem, strandType;
      CalcIndicesInVertexLevelMaster(GIndex, GId.x, globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex, numVerticesInTheStrand, indexForSharedMem, strandType);

      sharedPos[indexForSharedMem] = GetSharedPosition(globalVertexIndex);
      // Only xyz of the shared tangent slot is populated.
      sharedTangent[indexForSharedMem].xyz = GetSharedTangent(globalVertexIndex);
      GroupMemoryBarrierWithGroupSync();

      // Loop invariants for this vertex.
      const float3 guidePos = sharedPos[indexForSharedMem].xyz;
      const float3 guideTangent = sharedTangent[indexForSharedMem].xyz;
      const float factor = g_TipSeparationFactor * ((float)localVertexIndex / (float)numVerticesInTheStrand) + 1.0f;
      const float offsetScale = factor * CM_TO_METERS;

      for (uint i = 0; i < g_NumFollowHairsPerGuideHair; i++)
      {
          int globalFollowVertexIndex = globalVertexIndex + numVerticesInTheStrand * (i + 1);
          int globalFollowStrandIndex = globalStrandIndex + i + 1;

          float3 followPos = guidePos + offsetScale * g_FollowHairRootOffset[globalFollowStrandIndex].xyz;
          SetSharedPosition3(globalFollowVertexIndex, followPos);

          // Copy only xyz. Copying the whole shared element would also write its never-assigned
          // w component (if the element type has one) into the follow tangent. The guide-tangent
          // write in the length-constraint pass likewise touches xyz only.
          g_HairVertexTangents[globalFollowVertexIndex].xyz = guideTangent;
      }
  }