FeatureVelocity.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Console/IConsole.h>
  9. #include <AzCore/Serialization/EditContext.h>
  10. #include <AzCore/Serialization/SerializeContext.h>
  11. #include <EMotionFX/Source/ActorInstance.h>
  12. #include <Allocators.h>
  13. #include <EMotionFX/Source/EMotionFXManager.h>
  14. #include <EMotionFX/Source/EventManager.h>
  15. #include <EMotionFX/Source/TransformData.h>
  16. #include <FeatureMatrixTransformer.h>
  17. #include <FeatureVelocity.h>
  18. #include <FrameDatabase.h>
  19. #include <MotionMatchingInstance.h>
  20. #include <PoseDataJointVelocities.h>
  21. namespace EMotionFX::MotionMatching
  22. {
// Console variable used below as the multiplier applied to velocity vectors when debug drawing.
// NOTE(review): presumably defined in another translation unit and only declared here - confirm.
AZ_CVAR_EXTERNED(float, mm_debugDrawVelocityScale);
// NOTE(review): presumably emits the class allocator implementation for FeatureVelocity
// backed by MotionMatchAllocator - confirm against the macro definition.
AZ_CLASS_ALLOCATOR_IMPL(FeatureVelocity, MotionMatchAllocator)
  25. void FeatureVelocity::ExtractFeatureValues(const ExtractFeatureContext& context)
  26. {
  27. const ActorInstance* actorInstance = context.m_actorInstance;
  28. const Frame& frame = context.m_frameDatabase->GetFrame(context.m_frameIndex);
  29. AnimGraphPose* tempPose = context.m_posePool.RequestPose(actorInstance);
  30. {
  31. // Calculate the joint velocities for the sampled pose using the same method as we do for the frame database.
  32. PoseDataJointVelocities* velocityPoseData = tempPose->GetPose().GetAndPreparePoseData<PoseDataJointVelocities>(actorInstance);
  33. velocityPoseData->CalculateVelocity(actorInstance,
  34. context.m_posePool,
  35. frame.GetSourceMotion(),
  36. frame.GetSampleTime(),
  37. m_relativeToNodeIndex);
  38. const AZ::Vector3& velocity = velocityPoseData->GetVelocities()[m_jointIndex];
  39. context.m_featureMatrix.SetVector3(context.m_frameIndex, m_featureColumnOffset, velocity);
  40. }
  41. context.m_posePool.FreePose(tempPose);
  42. }
  43. void FeatureVelocity::FillQueryVector(QueryVector& queryVector, const QueryVectorContext& context)
  44. {
  45. PoseDataJointVelocities* velocityPoseData = context.m_currentPose.GetPoseData<PoseDataJointVelocities>();
  46. AZ_Assert(velocityPoseData, "Cannot calculate velocity feature cost without joint velocity pose data.");
  47. const AZ::Vector3 currentVelocity = velocityPoseData->GetVelocity(m_jointIndex);
  48. queryVector.SetVector3(currentVelocity, m_featureColumnOffset);
  49. }
  50. float FeatureVelocity::CalculateFrameCost(size_t frameIndex, const FrameCostContext& context) const
  51. {
  52. const AZ::Vector3 queryVelocity = context.m_queryVector.GetVector3(m_featureColumnOffset);
  53. const AZ::Vector3 frameVelocity = context.m_featureMatrix.GetVector3(frameIndex, m_featureColumnOffset);
  54. return CalcResidual(queryVelocity, frameVelocity);
  55. }
  56. void FeatureVelocity::DebugDraw(AzFramework::DebugDisplayRequests& debugDisplay,
  57. const Pose& pose,
  58. const AZ::Vector3& velocity,
  59. size_t jointIndex,
  60. size_t relativeToJointIndex,
  61. const AZ::Color& color)
  62. {
  63. const Transform jointModelTM = pose.GetModelSpaceTransform(jointIndex);
  64. const Transform relativeToWorldTM = pose.GetWorldSpaceTransform(relativeToJointIndex);
  65. const AZ::Vector3 jointPosition = relativeToWorldTM.TransformPoint(jointModelTM.m_position);
  66. const AZ::Vector3 velocityWorldSpace = relativeToWorldTM.TransformVector(velocity);
  67. DebugDrawVelocity(debugDisplay, jointPosition, velocityWorldSpace * mm_debugDrawVelocityScale, color);
  68. }
  69. void FeatureVelocity::DebugDraw(AzFramework::DebugDisplayRequests& debugDisplay,
  70. const Pose& currentPose,
  71. const FeatureMatrix& featureMatrix,
  72. const FeatureMatrixTransformer* featureTransformer,
  73. size_t frameIndex)
  74. {
  75. if (m_jointIndex == InvalidIndex)
  76. {
  77. return;
  78. }
  79. AZ::Vector3 velocity = featureMatrix.GetVector3(frameIndex, m_featureColumnOffset);
  80. if (featureTransformer)
  81. {
  82. velocity = featureTransformer->InverseTransform(velocity, m_featureColumnOffset);
  83. }
  84. DebugDraw(debugDisplay, currentPose, velocity, m_jointIndex, m_relativeToNodeIndex, m_debugColor);
  85. }
  86. void FeatureVelocity::Reflect(AZ::ReflectContext* context)
  87. {
  88. AZ::SerializeContext* serializeContext = azrtti_cast<AZ::SerializeContext*>(context);
  89. if (!serializeContext)
  90. {
  91. return;
  92. }
  93. serializeContext->Class<FeatureVelocity, Feature>()
  94. ->Version(1)
  95. ;
  96. AZ::EditContext* editContext = serializeContext->GetEditContext();
  97. if (!editContext)
  98. {
  99. return;
  100. }
  101. editContext->Class<FeatureVelocity>("FeatureVelocity", "Matches joint velocities.")
  102. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  103. ->Attribute(AZ::Edit::Attributes::AutoExpand, "")
  104. ;
  105. }
  106. size_t FeatureVelocity::GetNumDimensions() const
  107. {
  108. return 3;
  109. }
  110. AZStd::string FeatureVelocity::GetDimensionName(size_t index) const
  111. {
  112. AZStd::string result = m_jointName;
  113. result += '.';
  114. switch (index)
  115. {
  116. case 0: { result += "VelocityX"; break; }
  117. case 1: { result += "VelocityY"; break; }
  118. case 2: { result += "VelocityZ"; break; }
  119. default: { result += Feature::GetDimensionName(index); }
  120. }
  121. return result;
  122. }
  123. } // namespace EMotionFX::MotionMatching