MallocSpy.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. #pragma once
  2. #include "catch.hpp"
  3. #include <objbase.h>
  4. #include <wil/wistd_functional.h>
  5. #include <wrl/implements.h>
  6. #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
  7. // IMallocSpy requires you to implement all methods, but we often only want one or two...
  8. struct MallocSpy : Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMallocSpy>
  9. {
  10. wistd::function<SIZE_T(SIZE_T)> PreAllocCallback;
  11. virtual SIZE_T STDMETHODCALLTYPE PreAlloc(SIZE_T requestSize) override
  12. {
  13. if (PreAllocCallback)
  14. {
  15. return PreAllocCallback(requestSize);
  16. }
  17. return requestSize;
  18. }
  19. wistd::function<void*(void*)> PostAllocCallback;
  20. virtual void* STDMETHODCALLTYPE PostAlloc(void* ptr) override
  21. {
  22. if (PostAllocCallback)
  23. {
  24. return PostAllocCallback(ptr);
  25. }
  26. return ptr;
  27. }
  28. wistd::function<void*(void*)> PreFreeCallback;
  29. virtual void* STDMETHODCALLTYPE PreFree(void* ptr, BOOL wasSpyed) override
  30. {
  31. if (wasSpyed && PreFreeCallback)
  32. {
  33. return PreFreeCallback(ptr);
  34. }
  35. return ptr;
  36. }
  37. virtual void STDMETHODCALLTYPE PostFree(BOOL /*wasSpyed*/) override
  38. {
  39. }
  40. wistd::function<SIZE_T(void*, SIZE_T, void**)> PreReallocCallback;
  41. virtual SIZE_T STDMETHODCALLTYPE PreRealloc(void* ptr, SIZE_T requestSize, void** newPtr, BOOL wasSpyed) override
  42. {
  43. *newPtr = ptr;
  44. if (wasSpyed && PreReallocCallback)
  45. {
  46. return PreReallocCallback(ptr, requestSize, newPtr);
  47. }
  48. return requestSize;
  49. }
  50. wistd::function<void*(void*)> PostReallocCallback;
  51. virtual void* STDMETHODCALLTYPE PostRealloc(void* ptr, BOOL wasSpyed) override
  52. {
  53. if (wasSpyed && PostReallocCallback)
  54. {
  55. return PostReallocCallback(ptr);
  56. }
  57. return ptr;
  58. }
  59. wistd::function<void*(void*)> PreGetSizeCallback;
  60. virtual void* STDMETHODCALLTYPE PreGetSize(void* ptr, BOOL wasSpyed) override
  61. {
  62. if (wasSpyed && PreGetSizeCallback)
  63. {
  64. return PreGetSizeCallback(ptr);
  65. }
  66. return ptr;
  67. }
  68. wistd::function<SIZE_T(SIZE_T)> PostGetSizeCallback;
  69. virtual SIZE_T STDMETHODCALLTYPE PostGetSize(SIZE_T size, BOOL wasSpyed) override
  70. {
  71. if (wasSpyed && PostGetSizeCallback)
  72. {
  73. return PostGetSizeCallback(size);
  74. }
  75. return size;
  76. }
  77. wistd::function<void*(void*)> PreDidAllocCallback;
  78. virtual void* STDMETHODCALLTYPE PreDidAlloc(void* ptr, BOOL wasSpyed) override
  79. {
  80. if (wasSpyed && PreDidAllocCallback)
  81. {
  82. return PreDidAllocCallback(ptr);
  83. }
  84. return ptr;
  85. }
  86. virtual int STDMETHODCALLTYPE PostDidAlloc(void* /*ptr*/, BOOL /*wasSpyed*/, int result) override
  87. {
  88. return result;
  89. }
  90. virtual void STDMETHODCALLTYPE PreHeapMinimize() override
  91. {
  92. }
  93. virtual void STDMETHODCALLTYPE PostHeapMinimize() override
  94. {
  95. }
  96. };
  97. Microsoft::WRL::ComPtr<MallocSpy> MakeSecureDeleterMallocSpy()
  98. {
  99. using namespace Microsoft::WRL;
  100. auto result = Make<MallocSpy>();
  101. REQUIRE(result);
  102. result->PreFreeCallback = [](void* ptr)
  103. {
  104. ComPtr<IMalloc> malloc;
  105. if (SUCCEEDED(::CoGetMalloc(1, &malloc)))
  106. {
  107. auto size = malloc->GetSize(ptr);
  108. auto buffer = static_cast<byte*>(ptr);
  109. for (size_t i = 0; i < size; ++i)
  110. {
  111. REQUIRE(buffer[i] == 0);
  112. }
  113. }
  114. return ptr;
  115. };
  116. return result;
  117. }
  118. #endif