KdTree.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/std/algorithm.h>
  9. #include <AzCore/Debug/Timer.h>
  10. #include <KdTree.h>
  11. #include <Feature.h>
  12. #include <Allocators.h>
  13. namespace EMotionFX::MotionMatching
  14. {
  15. AZ_CLASS_ALLOCATOR_IMPL(KdTree, MotionMatchAllocator);
  16. AZ_CLASS_ALLOCATOR_IMPL(KdTree::Node, MotionMatchAllocator);
  17. KdTree::~KdTree()
  18. {
  19. Clear();
  20. }
  21. size_t KdTree::CalcNumDimensions(const AZStd::vector<Feature*>& features)
  22. {
  23. size_t result = 0;
  24. for (Feature* feature : features)
  25. {
  26. if (feature->GetId().IsNull())
  27. {
  28. continue;
  29. }
  30. result += feature->GetNumDimensions();
  31. }
  32. return result;
  33. }
  34. bool KdTree::Init(const FrameDatabase& frameDatabase,
  35. const FeatureMatrix& featureMatrix,
  36. const AZStd::vector<Feature*>& features,
  37. size_t maxDepth,
  38. size_t minFramesPerLeaf)
  39. {
  40. AZ_PROFILE_SCOPE(Animation, "MotionMatchingData::InitKdTree");
  41. #if !defined(_RELEASE)
  42. AZ::Debug::Timer timer;
  43. timer.Stamp();
  44. #endif
  45. Clear();
  46. // Verify the dimensions.
  47. m_numDimensions = CalcNumDimensions(features);
  48. // Going above a 48 dimensional tree would start eating up too much memory.
  49. const size_t maxNumDimensions = 48;
  50. if (m_numDimensions == 0 || m_numDimensions > maxNumDimensions)
  51. {
  52. AZ_Error("Motion Matching", false, "Cannot initialize KD-tree. KD-tree dimension (%d) has to be between 1 and %zu. Please use Feature::SetIncludeInKdTree(false) on some features.", m_numDimensions, maxNumDimensions);
  53. return false;
  54. }
  55. if (minFramesPerLeaf > 100000)
  56. {
  57. AZ_Error("Motion Matching", false, "KdTree minFramesPerLeaf (%d) cannot be bigger than 100000.", minFramesPerLeaf);
  58. return false;
  59. }
  60. if (maxDepth == 0)
  61. {
  62. AZ_Error("Motion Matching", false, "KdTree max depth (%d) cannot be zero", maxDepth);
  63. return false;
  64. }
  65. if (frameDatabase.GetNumFrames() == 0)
  66. {
  67. AZ_Error("Motion Matching", false, "Skipping to initialize KD-tree. No frames in the motion database.");
  68. return true;
  69. }
  70. m_maxDepth = maxDepth;
  71. m_minFramesPerLeaf = minFramesPerLeaf;
  72. // Not all features are present in the KD-tree, thus we need to remap KD-tree local feature columns to the
  73. // feature schema global feature columns.
  74. const AZStd::vector<size_t> localToSchemaFeatureColumns = CalcLocalToSchemaFeatureColumns(features);
  75. // Build the tree.
  76. BuildTreeNodes(frameDatabase, featureMatrix, localToSchemaFeatureColumns, aznew Node(), nullptr, 0);
  77. MergeSmallLeafNodesToParents();
  78. ClearFramesForNonEssentialNodes();
  79. RemoveZeroFrameLeafNodes();
  80. #if !defined(_RELEASE)
  81. const float initTime = timer.GetDeltaTimeInSeconds();
  82. AZ_TracePrintf("Motion Matching", "KD-Tree initialized in %.2f ms (numNodes = %d numDims = %d Memory used = %.2f MB).",
  83. initTime * 1000.0f,
  84. m_nodes.size(),
  85. m_numDimensions,
  86. static_cast<float>(CalcMemoryUsageInBytes()) / 1024.0f / 1024.0f);
  87. PrintStats();
  88. #endif
  89. return true;
  90. }
  91. AZStd::vector<size_t> KdTree::CalcLocalToSchemaFeatureColumns(const AZStd::vector<Feature*>& features) const
  92. {
  93. AZStd::vector<size_t> localToSchemaFeatureColumns;
  94. localToSchemaFeatureColumns.resize(m_numDimensions);
  95. size_t currentColumn = 0;
  96. for (const Feature* feature : features)
  97. {
  98. const size_t numDimensions = feature->GetNumDimensions();
  99. const size_t featureColumnOffset = feature->GetColumnOffset();
  100. for (size_t i = 0; i < numDimensions; ++i)
  101. {
  102. localToSchemaFeatureColumns[currentColumn] = featureColumnOffset + i;
  103. currentColumn++;
  104. }
  105. }
  106. AZ_Assert(m_numDimensions == currentColumn, "There should be a column index mapping for each of the available dimensions.");
  107. return localToSchemaFeatureColumns;
  108. }
  109. void KdTree::Clear()
  110. {
  111. // delete all nodes
  112. for (Node* node : m_nodes)
  113. {
  114. delete node;
  115. }
  116. m_nodes.clear();
  117. m_numDimensions = 0;
  118. }
  119. size_t KdTree::CalcMemoryUsageInBytes() const
  120. {
  121. size_t totalBytes = 0;
  122. for (const Node* node : m_nodes)
  123. {
  124. totalBytes += sizeof(Node);
  125. totalBytes += node->m_frames.capacity() * sizeof(size_t);
  126. }
  127. totalBytes += sizeof(KdTree);
  128. return totalBytes;
  129. }
  130. bool KdTree::IsInitialized() const
  131. {
  132. return (m_numDimensions != 0);
  133. }
  134. size_t KdTree::GetNumNodes() const
  135. {
  136. return m_nodes.size();
  137. }
  138. size_t KdTree::GetNumDimensions() const
  139. {
  140. return m_numDimensions;
  141. }
  142. void KdTree::BuildTreeNodes(const FrameDatabase& frameDatabase,
  143. const FeatureMatrix& featureMatrix,
  144. const AZStd::vector<size_t>& localToSchemaFeatureColumns,
  145. Node* node,
  146. Node* parent,
  147. size_t dimension,
  148. bool leftSide)
  149. {
  150. node->m_parent = parent;
  151. node->m_dimension = dimension;
  152. m_nodes.emplace_back(node);
  153. // Fill the frames array and calculate the median.
  154. AZStd::vector<float> frameFeatureValues;
  155. FillFramesForNode(node, frameDatabase, featureMatrix, localToSchemaFeatureColumns, frameFeatureValues, parent, leftSide);
  156. // Prevent splitting further when we don't want to.
  157. const size_t maxDimensions = AZ::GetMin(m_numDimensions, m_maxDepth);
  158. if (node->m_frames.size() < m_minFramesPerLeaf * 2 ||
  159. dimension >= maxDimensions)
  160. {
  161. return;
  162. }
  163. // Create the left node.
  164. Node* leftNode = aznew Node();
  165. AZ_Assert(!node->m_leftNode, "Expected the parent left node to be a nullptr");
  166. node->m_leftNode = leftNode;
  167. BuildTreeNodes(frameDatabase, featureMatrix, localToSchemaFeatureColumns, leftNode, node, dimension + 1, true);
  168. // Create the right node.
  169. Node* rightNode = aznew Node();
  170. AZ_Assert(!node->m_rightNode, "Expected the parent right node to be a nullptr");
  171. node->m_rightNode = rightNode;
  172. BuildTreeNodes(frameDatabase, featureMatrix, localToSchemaFeatureColumns, rightNode, node, dimension + 1, false);
  173. }
  174. void KdTree::ClearFramesForNonEssentialNodes()
  175. {
  176. for (Node* node : m_nodes)
  177. {
  178. if (node->m_leftNode && node->m_rightNode)
  179. {
  180. node->m_frames.clear();
  181. node->m_frames.shrink_to_fit();
  182. }
  183. }
  184. }
  185. void KdTree::RemoveLeafNode(Node* node)
  186. {
  187. Node* parent = node->m_parent;
  188. if (parent->m_leftNode == node)
  189. {
  190. parent->m_leftNode = nullptr;
  191. }
  192. if (parent->m_rightNode == node)
  193. {
  194. parent->m_rightNode = nullptr;
  195. }
  196. // Remove it from the node vector.
  197. const auto location = AZStd::find(m_nodes.begin(), m_nodes.end(), node);
  198. AZ_Assert(location != m_nodes.end(), "Expected to find the item to remove.");
  199. m_nodes.erase(location);
  200. delete node;
  201. }
  202. void KdTree::MergeSmallLeafNodesToParents()
  203. {
  204. // If the tree is empty or only has a single node, there is nothing to merge.
  205. if (m_nodes.size() < 2)
  206. {
  207. return;
  208. }
  209. AZStd::vector<Node*> nodesToRemove;
  210. for (Node* node : m_nodes)
  211. {
  212. // If we are a leaf node and we don't have enough frames.
  213. if ((!node->m_leftNode && !node->m_rightNode) &&
  214. node->m_frames.size() < m_minFramesPerLeaf)
  215. {
  216. nodesToRemove.emplace_back(node);
  217. }
  218. }
  219. // Remove the actual nodes.
  220. for (Node* node : nodesToRemove)
  221. {
  222. RemoveLeafNode(node);
  223. }
  224. }
  225. void KdTree::RemoveZeroFrameLeafNodes()
  226. {
  227. AZStd::vector<Node*> nodesToRemove;
  228. // Build a list of leaf nodes to remove.
  229. // These are ones that have no feature inside them.
  230. for (Node* node : m_nodes)
  231. {
  232. if ((!node->m_leftNode && !node->m_rightNode) &&
  233. node->m_frames.empty())
  234. {
  235. nodesToRemove.emplace_back(node);
  236. }
  237. }
  238. // Remove the actual nodes.
  239. for (Node* node : nodesToRemove)
  240. {
  241. RemoveLeafNode(node);
  242. }
  243. }
  244. void KdTree::FillFramesForNode(Node* node,
  245. const FrameDatabase& frameDatabase,
  246. const FeatureMatrix& featureMatrix,
  247. const AZStd::vector<size_t>& localToSchemaFeatureColumns,
  248. AZStd::vector<float>& frameFeatureValues,
  249. Node* parent,
  250. bool leftSide)
  251. {
  252. frameFeatureValues.clear();
  253. if (parent)
  254. {
  255. // Assume half of the parent frames are in this node.
  256. const size_t numExpectedFrames = (parent->m_frames.size() / 2) + 1;
  257. frameFeatureValues.reserve(numExpectedFrames);
  258. node->m_frames.reserve(numExpectedFrames);
  259. // Add parent frames to this node, but only ones that should be on this side.
  260. for (const size_t frameIndex : parent->m_frames)
  261. {
  262. // Remap local to the KD-tree feature column to the feature schema global column and read the value directly from the feature matrix.
  263. const float featureValue = featureMatrix(frameIndex, localToSchemaFeatureColumns[parent->m_dimension]);
  264. frameFeatureValues.push_back(featureValue);
  265. if (leftSide)
  266. {
  267. if (featureValue <= parent->m_median)
  268. {
  269. node->m_frames.emplace_back(frameIndex);
  270. }
  271. }
  272. else
  273. {
  274. if (featureValue > parent->m_median)
  275. {
  276. node->m_frames.emplace_back(frameIndex);
  277. }
  278. }
  279. }
  280. }
  281. else // We're the root node.
  282. {
  283. node->m_frames.reserve(frameDatabase.GetNumFrames());
  284. for (const Frame& frame : frameDatabase.GetFrames())
  285. {
  286. const size_t frameIndex = frame.GetFrameIndex();
  287. node->m_frames.emplace_back(frameIndex);
  288. // Remap local to the KD-tree feature column to the feature schema global column and read the value directly from the feature matrix.
  289. const float featureValue = featureMatrix(frameIndex, localToSchemaFeatureColumns[node->m_dimension]);
  290. frameFeatureValues.push_back(featureValue);
  291. }
  292. }
  293. // Calculate the median in O(n).
  294. node->m_median = 0.0f;
  295. if (!frameFeatureValues.empty())
  296. {
  297. auto medianIterator = frameFeatureValues.begin() + frameFeatureValues.size() / 2;
  298. AZStd::nth_element(frameFeatureValues.begin(), medianIterator, frameFeatureValues.end());
  299. node->m_median = frameFeatureValues[frameFeatureValues.size() / 2];
  300. }
  301. }
  302. void KdTree::RecursiveCalcNumFrames(Node* node, size_t& outNumFrames) const
  303. {
  304. if (node->m_leftNode && node->m_rightNode)
  305. {
  306. RecursiveCalcNumFrames(node->m_leftNode, outNumFrames);
  307. RecursiveCalcNumFrames(node->m_rightNode, outNumFrames);
  308. }
  309. else
  310. {
  311. outNumFrames += node->m_frames.size();
  312. }
  313. }
  314. void KdTree::PrintStats()
  315. {
  316. #if !defined(_RELEASE)
  317. size_t leftNumFrames = 0;
  318. size_t rightNumFrames = 0;
  319. if (m_nodes[0]->m_leftNode)
  320. {
  321. RecursiveCalcNumFrames(m_nodes[0]->m_leftNode, leftNumFrames);
  322. }
  323. if (m_nodes[0]->m_rightNode)
  324. {
  325. RecursiveCalcNumFrames(m_nodes[0]->m_rightNode, rightNumFrames);
  326. }
  327. const float numFrames = static_cast<float>(leftNumFrames + rightNumFrames);
  328. const float halfFrames = numFrames / 2.0f;
  329. const float balanceScore = 100.0f - (AZ::GetAbs(halfFrames - static_cast<float>(leftNumFrames)) / numFrames) * 100.0f;
  330. // Get the maximum depth.
  331. size_t maxDepth = 0;
  332. for (const Node* node : m_nodes)
  333. {
  334. maxDepth = AZ::GetMax(maxDepth, node->m_dimension);
  335. }
  336. AZ_TracePrintf("Motion Matching", " KdTree Balance Info: leftSide=%d rightSide=%d score=%.2f totalFrames=%d maxDepth=%d", leftNumFrames, rightNumFrames, balanceScore, leftNumFrames + rightNumFrames, maxDepth);
  337. size_t numLeafNodes = 0;
  338. size_t numZeroNodes = 0;
  339. size_t minFrames = 1000000000;
  340. size_t maxFrames = 0;
  341. AZStd::string framesString;
  342. for (const Node* node : m_nodes)
  343. {
  344. if (node->m_leftNode || node->m_rightNode)
  345. {
  346. continue;
  347. }
  348. numLeafNodes++;
  349. if (node->m_frames.empty())
  350. {
  351. numZeroNodes++;
  352. }
  353. if (!framesString.empty())
  354. {
  355. framesString += ", ";
  356. }
  357. framesString += AZStd::to_string(node->m_frames.size());
  358. minFrames = AZ::GetMin(minFrames, node->m_frames.size());
  359. maxFrames = AZ::GetMax(maxFrames, node->m_frames.size());
  360. }
  361. AZ_TracePrintf("Motion Matching", " Frames = (%s)", framesString.c_str());
  362. const size_t avgFrames = (leftNumFrames + rightNumFrames) / numLeafNodes;
  363. AZ_TracePrintf("Motion Matching", " KdTree Node Info: leafs=%d avgFrames=%d zeroFrames=%d minFrames=%d maxFrames=%d", numLeafNodes, avgFrames, numZeroNodes, minFrames, maxFrames);
  364. #endif
  365. }
  366. void KdTree::FindNearestNeighbors(const AZStd::vector<float>& frameFloats, AZStd::vector<size_t>& resultFrameIndices) const
  367. {
  368. AZ_Assert(IsInitialized() && !m_nodes.empty(), "Expecting a valid and initialized kdTree. Did you forget to call KdTree::Init()?");
  369. Node* curNode = m_nodes[0];
  370. // Step as far as we need to through the kdTree.
  371. Node* nodeToSearch = nullptr;
  372. const size_t numDimensions = frameFloats.size();
  373. for (size_t d = 0; d < numDimensions; ++d)
  374. {
  375. AZ_Assert(curNode->m_dimension == d, "Dimension mismatch");
  376. // We have children in both directions.
  377. if (curNode->m_leftNode && curNode->m_rightNode)
  378. {
  379. curNode = (frameFloats[d] <= curNode->m_median) ? curNode->m_leftNode : curNode->m_rightNode;
  380. }
  381. else if (!curNode->m_leftNode && !curNode->m_rightNode) // we have a leaf node
  382. {
  383. nodeToSearch = curNode;
  384. }
  385. else
  386. {
  387. // We have either a left and right node, so we're not at a leaf yet.
  388. if (curNode->m_leftNode)
  389. {
  390. if (frameFloats[d] <= curNode->m_median)
  391. {
  392. curNode = curNode->m_leftNode;
  393. }
  394. else
  395. {
  396. nodeToSearch = curNode;
  397. }
  398. }
  399. else // We have a right node.
  400. {
  401. if (frameFloats[d] > curNode->m_median)
  402. {
  403. curNode = curNode->m_rightNode;
  404. }
  405. else
  406. {
  407. nodeToSearch = curNode;
  408. }
  409. }
  410. }
  411. // If we found our search node, perform a linear search through the frames inside this node.
  412. if (nodeToSearch)
  413. {
  414. //AZ_Assert(d == nodeToSearch->m_dimension, "Dimension mismatch inside kdTree nearest neighbor search.");
  415. FindNearestNeighbors(nodeToSearch, frameFloats, resultFrameIndices);
  416. return;
  417. }
  418. }
  419. FindNearestNeighbors(curNode, frameFloats, resultFrameIndices);
  420. }
  421. void KdTree::FindNearestNeighbors([[maybe_unused]] Node* node, [[maybe_unused]] const AZStd::vector<float>& frameFloats, AZStd::vector<size_t>& resultFrameIndices) const
  422. {
  423. resultFrameIndices = node->m_frames;
  424. }
  425. } // namespace EMotionFX::MotionMatching