BlendTreeMaskNodeTests.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Tests/AnimGraphFixture.h>
  9. #include <EMotionFX/Source/Actor.h>
  10. #include <EMotionFX/Source/AnimGraph.h>
  11. #include <EMotionFX/Source/AnimGraphStateMachine.h>
  12. #include <EMotionFX/Source/AnimGraphMotionNode.h>
  13. #include <EMotionFX/Source/BlendTree.h>
  14. #include <EMotionFX/Source/BlendTreeMaskNode.h>
  15. #include <EMotionFX/Source/EMotionFXManager.h>
  16. #include <EMotionFX/Source/Node.h>
  17. #include <EMotionFX/Source/TransformData.h>
  18. #include <Tests/TestAssetCode/SimpleActors.h>
  19. #include <Tests/TestAssetCode/ActorFactory.h>
  20. namespace EMotionFX
  21. {
  22. class BlendTreeTestInputNode
  23. : public AnimGraphNode
  24. {
  25. public:
  26. AZ_CLASS_ALLOCATOR(BlendTreeTestInputNode, AnimGraphAllocator)
  27. AZ_RTTI(BlendTreeTestInputNode, "{72595B5C-045C-4DB1-88A4-40BC4560D7AF}", AnimGraphNode)
  28. enum
  29. {
  30. OUTPUTPORT_RESULT = 0
  31. };
  32. BlendTreeTestInputNode(float value)
  33. : AnimGraphNode()
  34. , m_identificationValue(value)
  35. {
  36. InitOutputPorts(1);
  37. SetupOutputPortAsPose("Output Pose", OUTPUTPORT_RESULT, OUTPUTPORT_RESULT);
  38. }
  39. AZ::Color GetVisualColor() const override { return AZ::Color(1.0f, 1.0f, 0.0f, 1.0f); }
  40. bool GetHasOutputPose() const override { return true; }
  41. const char* GetPaletteName() const override { return "BlendTreeTestInputNode"; }
  42. AnimGraphObject::ECategory GetPaletteCategory() const override { return AnimGraphObject::CATEGORY_SOURCES; }
  43. AnimGraphPose* GetMainOutputPose(AnimGraphInstance* animGraphInstance) const override { return GetOutputPose(animGraphInstance, OUTPUTPORT_RESULT)->GetValue(); }
  44. bool InitAfterLoading(AnimGraph* animGraph) override
  45. {
  46. if (!AnimGraphNode::InitAfterLoading(animGraph))
  47. {
  48. return false;
  49. }
  50. InitInternalAttributesForAllInstances();
  51. Reinit();
  52. return true;
  53. }
  54. void Output(AnimGraphInstance* animGraphInstance) override
  55. {
  56. RequestPoses(animGraphInstance);
  57. AnimGraphPose* outputAnimGraphPose = GetOutputPose(animGraphInstance, OUTPUTPORT_RESULT)->GetValue();
  58. outputAnimGraphPose->InitFromBindPose(animGraphInstance->GetActorInstance());
  59. Pose& outputPose = outputAnimGraphPose->GetPose();
  60. // Output the assigned value of the node for each joint so that we can identify from which input each joint is coming from.
  61. const size_t numJoints = outputPose.GetNumTransforms();
  62. for (size_t i = 0; i < numJoints; ++i)
  63. {
  64. Transform transform = outputPose.GetLocalSpaceTransform(i);
  65. transform.m_position = AZ::Vector3(m_identificationValue, m_identificationValue, m_identificationValue);
  66. outputPose.SetLocalSpaceTransform(i, transform);
  67. }
  68. }
  69. private:
  70. float m_identificationValue;
  71. };
  72. using MaskNodeTestParam = std::vector<std::vector<std::string>>;
  73. /*
  74. * The general idea is to identify the origin of the joints by embedding identification values into the joint transform
  75. * and inside the test extract that value and thus know from which mask input it belongs to.
  76. * We create a blend tree with a mask node having several input nodes. The first one representing the base pose and three
  77. * input mask nodes with a customizable mask which comes in by the test parameter.
  78. * We run several tests with different variations of masks and check if the output transforms for each joint corresponds with
  79. * the set masks and if the mask node picked and overwrote the correct transforms.
  80. */
  81. class BlendTreeMaskNodeTestFixture
  82. : public AnimGraphFixture
  83. , public ::testing::WithParamInterface<MaskNodeTestParam>
  84. {
  85. public:
  86. void ConstructActor() override
  87. {
  88. m_actor = ActorFactory::CreateAndInit<AllRootJointsActor>(5);
  89. }
  90. AZStd::vector<AZStd::string> ConstructMask(const std::vector<std::string>& in)
  91. {
  92. AZStd::vector<AZStd::string> result;
  93. result.reserve(in.size());
  94. for (const std::string& str : in)
  95. {
  96. result.emplace_back(AZStd::string(str.c_str(), str.size()));
  97. }
  98. return result;
  99. }
  100. AZ::Outcome<size_t> FindMaskIndexForJoint(size_t jointIndex) const
  101. {
  102. const MaskNodeTestParam& param = GetParam();
  103. Skeleton* skeleton = m_actor->GetSkeleton();
  104. const size_t numMasks = param.size();
  105. for (size_t maskIndex = 0; maskIndex < numMasks; ++maskIndex)
  106. {
  107. const std::vector<std::string>& mask = param[maskIndex];
  108. const Node* joint = skeleton->GetNode(jointIndex);
  109. const char* jointName = joint->GetName();
  110. // Is joint in the current mask? Return the index in this case.
  111. if (std::find(mask.begin(), mask.end(), jointName) != mask.end())
  112. {
  113. return AZ::Success(maskIndex);
  114. }
  115. }
  116. return AZ::Failure();
  117. }
  118. void ConstructGraph() override
  119. {
  120. AnimGraphFixture::ConstructGraph();
  121. const MaskNodeTestParam& param = GetParam();
  122. m_blendTreeAnimGraph = AnimGraphFactory::Create<OneBlendTreeNodeAnimGraph>();
  123. m_rootStateMachine = m_blendTreeAnimGraph->GetRootStateMachine();
  124. m_blendTree = m_blendTreeAnimGraph->GetBlendTreeNode();
  125. /*
  126. +-----------+
  127. | Base Pose +----------+
  128. +-----------+ |
  129. |
  130. +----------+ >+-----------+ +-------+
  131. | Mask 0 +----------->| Pose Mask +-------------->+ Final |
  132. +----------+ ------>| | +-------+
  133. | >+-----------+
  134. +----------+ | |
  135. | Mask 1 +-----+ |
  136. +----------+ |
  137. |
  138. +-------------+ |
  139. | Mask 3 +--------+
  140. +-------------+
  141. */
  142. m_maskNode = aznew BlendTreeMaskNode();
  143. m_blendTree->AddChildNode(m_maskNode);
  144. BlendTreeFinalNode* finalNode = aznew BlendTreeFinalNode();
  145. m_blendTree->AddChildNode(finalNode);
  146. finalNode->AddConnection(m_maskNode, BlendTreeMaskNode::OUTPUTPORT_RESULT, BlendTreeFinalNode::PORTID_INPUT_POSE);
  147. m_basePoseNode = aznew BlendTreeTestInputNode(static_cast<float>(m_basePosePosValue));
  148. m_blendTree->AddChildNode(m_basePoseNode);
  149. m_maskNode->AddConnection(m_basePoseNode, BlendTreeTestInputNode::OUTPUTPORT_RESULT, BlendTreeMaskNode::INPUTPORT_BASEPOSE);
  150. for (uint16 i = 0; i < m_numMaskInputNodes; ++i)
  151. {
  152. BlendTreeTestInputNode* inputNode = aznew BlendTreeTestInputNode(static_cast<float>(i));
  153. m_blendTree->AddChildNode(inputNode);
  154. m_maskNode->AddConnection(inputNode, BlendTreeTestInputNode::OUTPUTPORT_RESULT, BlendTreeMaskNode::INPUTPORT_START + i);
  155. m_maskInputNodes.push_back(inputNode);
  156. }
  157. const size_t numMasks = param.size();
  158. ASSERT_EQ(numMasks, m_numMaskInputNodes)
  159. << "The number of provides masks in the parameter (" << numMasks << ") should match the number of created "
  160. << "input mask nodes (" << m_numMaskInputNodes << ").";
  161. for (size_t i = 0; i < numMasks; ++i)
  162. {
  163. m_maskNode->SetMask(i, ConstructMask(param[i]));
  164. }
  165. m_blendTreeAnimGraph->InitAfterLoading();
  166. }
  167. void SetUp() override
  168. {
  169. AnimGraphFixture::SetUp();
  170. m_animGraphInstance->Destroy();
  171. m_animGraphInstance = m_blendTreeAnimGraph->GetAnimGraphInstance(m_actorInstance, m_motionSet);
  172. }
  173. public:
  174. BlendTreeMaskNode* m_maskNode = nullptr;
  175. BlendTreeTestInputNode* m_basePoseNode = nullptr;
  176. const size_t m_basePosePosValue = 100; // Special identification value for the base pose to easily distinguish it from the mask indices.
  177. std::vector<BlendTreeTestInputNode*> m_maskInputNodes;
  178. size_t m_numMaskInputNodes = 3;
  179. BlendTree* m_blendTree = nullptr;
  180. };
  181. TEST_P(BlendTreeMaskNodeTestFixture, MaskTests)
  182. {
  183. GetEMotionFX().Update(0.0f);
  184. Skeleton* skeleton = m_actor->GetSkeleton();
  185. const size_t numJoints = skeleton->GetNumNodes();
  186. TransformData* transformData = m_actorInstance->GetTransformData();
  187. Pose* pose = transformData->GetCurrentPose();
  188. // Iterate through the joints and make sure their transforms originate according to the mask setup.
  189. for (size_t jointIndex = 0; jointIndex < numJoints; jointIndex++)
  190. {
  191. const Node* joint = skeleton->GetNode(jointIndex);
  192. const char* jointName = joint->GetName();
  193. const Transform& transform = pose->GetModelSpaceTransform(jointIndex);
  194. // The components of the position embed the origin.
  195. // If the compareValue equals m_basePosePosValue, it originates from the base pose input.
  196. // In case the joint is part of any of the masks and got overwriten by them, the compareValue represents the mask index.
  197. const size_t compareValue = static_cast<size_t>(transform.m_position.GetX());
  198. AZ::Outcome<size_t> maskIndex = FindMaskIndexForJoint(jointIndex);
  199. if (maskIndex.IsSuccess())
  200. {
  201. EXPECT_EQ(compareValue, maskIndex.GetValue())
  202. << "Joint '" << jointName << "' is part of mask " << maskIndex.GetValue()
  203. << " while the transform originated from input number " << compareValue
  204. << ".";
  205. }
  206. else
  207. {
  208. EXPECT_EQ(compareValue, m_basePosePosValue)
  209. << "Joint '" << jointName << "' is not part of any mask while the transform "
  210. << "originated from input number " << compareValue << ". It should originate "
  211. << "from the base pose input.";
  212. }
  213. }
  214. }
  215. std::vector<MaskNodeTestParam> maskNodeTestData
  216. {
  217. {
  218. {},
  219. {},
  220. {},
  221. },
  222. {
  223. { "rootJoint" },
  224. {},
  225. {},
  226. },
  227. {
  228. { "rootJoint", "joint2" },
  229. {},
  230. {},
  231. },
  232. {
  233. { "rootJoint", "joint1", "joint2" },
  234. {},
  235. {},
  236. },
  237. {
  238. { "rootJoint", "joint1", "joint2", "joint3", "joint4" },
  239. {},
  240. {},
  241. },
  242. {
  243. {},
  244. { "joint1", "joint3" },
  245. {},
  246. },
  247. {
  248. {},
  249. {},
  250. { "joint2", "joint4" },
  251. },
  252. {
  253. { "rootJoint", "joint1" },
  254. { "joint3", "joint4" },
  255. {},
  256. },
  257. {
  258. { "rootJoint", "joint1" },
  259. {},
  260. { "joint3", "joint4" },
  261. },
  262. {
  263. {},
  264. { "rootJoint", "joint1" },
  265. { "joint3", "joint4" },
  266. },
  267. {
  268. { "rootJoint" },
  269. { "joint1", "joint2" },
  270. { "joint3", "joint4" },
  271. },
  272. };
  273. INSTANTIATE_TEST_SUITE_P(BlendTreeMaskNode,
  274. BlendTreeMaskNodeTestFixture,
  275. ::testing::ValuesIn(maskNodeTestData));
  276. } // namespace EMotionFX