ComApartmentVariableTests.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
#include <wil/com_apartment_variable.h>
#include <wil/com.h>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "common.h"
  5. #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
  6. template <typename... args_t>
  7. inline void LogOutput(_Printf_format_string_ PCWSTR format, args_t&&... args)
  8. {
  9. OutputDebugStringW(wil::str_printf_failfast<wil::unique_cotaskmem_string>(format, wistd::forward<args_t>(args)...).get());
  10. }
  11. inline bool IsComInitialized()
  12. {
  13. APTTYPE type{}; APTTYPEQUALIFIER qualifier{};
  14. return CoGetApartmentType(&type, &qualifier) == S_OK;
  15. }
  16. inline void WaitForAllComApartmentsToRundown()
  17. {
  18. while (IsComInitialized())
  19. {
  20. Sleep(0);
  21. }
  22. }
  23. void co_wait(const wil::unique_event& e)
  24. {
  25. HANDLE raw[] = { e.get() };
  26. ULONG index{};
  27. REQUIRE_SUCCEEDED(CoWaitForMultipleHandles(COWAIT_DISPATCH_CALLS, INFINITE, static_cast<ULONG>(std::size(raw)), raw, &index));
  28. }
  29. void RunApartmentVariableTest(void(*test)())
  30. {
  31. test();
  32. // Apartment variable rundown is async, wait for the last COM apartment
  33. // to rundown before proceeding to the next test.
  34. WaitForAllComApartmentsToRundown();
  35. }
  36. struct mock_platform
  37. {
  38. static unsigned long long GetApartmentId()
  39. {
  40. APTTYPE type; APTTYPEQUALIFIER qualifer;
  41. REQUIRE_SUCCEEDED(CoGetApartmentType(&type, &qualifer)); // ensure COM is inited
  42. // Approximate apartment Id
  43. if (type == APTTYPE_STA)
  44. {
  45. REQUIRE_FALSE(GetCurrentThreadId() < APTTYPE_MAINSTA);
  46. return GetCurrentThreadId();
  47. }
  48. else
  49. {
  50. // APTTYPE_MTA (1), APTTYPE_NA (2), APTTYPE_MAINSTA (3)
  51. return type;
  52. }
  53. }
  54. static auto RegisterForApartmentShutdown(IApartmentShutdown* observer)
  55. {
  56. const auto id = GetApartmentId();
  57. auto apt_observers = m_observers.find(id);
  58. if (apt_observers == m_observers.end())
  59. {
  60. m_observers.insert({ id, { observer} });
  61. }
  62. else
  63. {
  64. apt_observers->second.emplace_back(observer);
  65. }
  66. return shutdown_type{ reinterpret_cast<APARTMENT_SHUTDOWN_REGISTRATION_COOKIE>(id) };
  67. }
  68. static void UnRegisterForApartmentShutdown(APARTMENT_SHUTDOWN_REGISTRATION_COOKIE cookie)
  69. {
  70. auto id = reinterpret_cast<unsigned long long>(cookie);
  71. m_observers.erase(id);
  72. }
  73. using shutdown_type = wil::unique_any<APARTMENT_SHUTDOWN_REGISTRATION_COOKIE, decltype(&UnRegisterForApartmentShutdown), UnRegisterForApartmentShutdown>;
  74. // This is needed to simulate the platform for unit testing.
  75. static auto CoInitializeEx(DWORD coinitFlags = 0 /*COINIT_MULTITHREADED*/)
  76. {
  77. return wil::scope_exit([aptId = GetCurrentThreadId(), init = wil::CoInitializeEx(coinitFlags)]()
  78. {
  79. const auto id = GetApartmentId();
  80. auto apt_observers = m_observers.find(id);
  81. if (apt_observers != m_observers.end())
  82. {
  83. const auto& observers = apt_observers->second;
  84. for (auto& observer : observers)
  85. {
  86. observer->OnUninitialize(id);
  87. }
  88. m_observers.erase(apt_observers);
  89. }
  90. });
  91. }
  92. // Enable the test hook to force losing the race
  93. inline static constexpr unsigned long AsyncRundownDelayForTestingRaces = 1; // enable test hook
  94. inline static std::unordered_map<unsigned long long, std::vector<wil::com_ptr<IApartmentShutdown>>> m_observers;
  95. };
  96. auto fn() { return 42; };
  97. auto fn2() { return 43; };
// Namespace-scope apartment variables that use the `ignore` leak action.
// g_v1 is bound to mock_platform. TestApartmentVariableAllMethods touches it for
// every platform that test is instantiated with.
wil::apartment_variable<int, wil::apartment_variable_leak_action::ignore, mock_platform> g_v1;
// NOTE(review): g_v2 (default platform) is not referenced anywhere else in this
// file. Presumably it only checks that such a global compiles and is destroyed
// cleanly at process exit — confirm.
wil::apartment_variable<int, wil::apartment_variable_leak_action::ignore> g_v2;
  100. template <typename platform = wil::apartment_variable_platform>
  101. void TestApartmentVariableAllMethods()
  102. {
  103. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  104. std::ignore = g_v1.get_or_create(fn);
  105. wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform> v1;
  106. REQUIRE(v1.get_if() == nullptr);
  107. REQUIRE(v1.get_or_create(fn) == 42);
  108. int value = 43;
  109. v1.set(value);
  110. REQUIRE(v1.get_or_create(fn) == 43);
  111. REQUIRE(v1.get_existing() == 43);
  112. v1.clear();
  113. REQUIRE(v1.get_if() == nullptr);
  114. }
  115. template <typename platform = wil::apartment_variable_platform>
  116. void TestApartmentVariableGetOrCreateForms()
  117. {
  118. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  119. wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform> v1;
  120. REQUIRE(v1.get_or_create(fn) == 42);
  121. v1.clear();
  122. REQUIRE(v1.get_or_create([&]
  123. {
  124. return 1;
  125. }) == 1);
  126. v1.clear();
  127. REQUIRE(v1.get_or_create() == 0);
  128. }
  129. template <typename platform = wil::apartment_variable_platform>
  130. void TestApartmentVariableLifetimes()
  131. {
  132. wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform> av1, av2;
  133. {
  134. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  135. auto v1 = av1.get_or_create(fn);
  136. REQUIRE(av1.storage().size() == 1);
  137. auto v2 = av1.get_existing();
  138. REQUIRE(av1.current_apartment_variable_count() == 1);
  139. REQUIRE(v1 == v2);
  140. }
  141. {
  142. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  143. auto v1 = av1.get_or_create(fn);
  144. auto v2 = av2.get_or_create(fn2);
  145. REQUIRE((av1.current_apartment_variable_count() == 2));
  146. REQUIRE(v1 != v2);
  147. REQUIRE(av1.storage().size() == 1);
  148. }
  149. REQUIRE(av1.storage().size() == 0);
  150. {
  151. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  152. auto v = av1.get_or_create(fn);
  153. REQUIRE(av1.current_apartment_variable_count() == 1);
  154. std::thread([&]() // join below makes this ok
  155. {
  156. SetThreadDescription(GetCurrentThread(), L"STA");
  157. auto coUninit = platform::CoInitializeEx(COINIT_APARTMENTTHREADED);
  158. std::ignore = av1.get_or_create(fn);
  159. REQUIRE(av1.storage().size() == 2);
  160. REQUIRE(av1.current_apartment_variable_count() == 1);
  161. }).join();
  162. REQUIRE(av1.storage().size() == 1);
  163. av1.get_or_create(fn)++;
  164. v = av1.get_existing();
  165. REQUIRE(v == 43);
  166. }
  167. {
  168. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  169. std::ignore = av1.get_or_create(fn);
  170. REQUIRE(av1.current_apartment_variable_count() == 1);
  171. int i = 1;
  172. av1.set(i);
  173. av1.clear();
  174. REQUIRE(av1.current_apartment_variable_count() == 0);
  175. // will fail fast since clear() was called.
  176. // av1.set(1);
  177. av1.clear_all_apartments_async().get();
  178. }
  179. REQUIRE(av1.storage().size() == 0);
  180. }
  181. template <typename platform = wil::apartment_variable_platform>
  182. void TestMultipleApartments()
  183. {
  184. wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform> av1, av2;
  185. wil::unique_event t1Created{ wil::EventOptions::None }, t2Created{ wil::EventOptions::None };
  186. wil::unique_event t1Shutdown{ wil::EventOptions::None }, t2Shutdown{ wil::EventOptions::None };
  187. auto apt1_thread = std::thread([&]() // join below makes this ok
  188. {
  189. SetThreadDescription(GetCurrentThread(), L"STA 1");
  190. auto coUninit = platform::CoInitializeEx(COINIT_APARTMENTTHREADED);
  191. std::ignore = av1.get_or_create(fn);
  192. std::ignore = av2.get_or_create(fn);
  193. t1Created.SetEvent();
  194. co_wait(t1Shutdown);
  195. });
  196. auto apt2_thread = std::thread([&]() // join below makes this ok
  197. {
  198. SetThreadDescription(GetCurrentThread(), L"STA 2");
  199. auto coUninit = platform::CoInitializeEx(COINIT_APARTMENTTHREADED);
  200. std::ignore = av1.get_or_create(fn);
  201. std::ignore = av2.get_or_create(fn);
  202. t2Created.SetEvent();
  203. co_wait(t2Shutdown);
  204. });
  205. t1Created.wait();
  206. t2Created.wait();
  207. av1.clear_all_apartments_async().get();
  208. av2.clear_all_apartments_async().get();
  209. t1Shutdown.SetEvent();
  210. t2Shutdown.SetEvent();
  211. apt1_thread.join();
  212. apt2_thread.join();
  213. REQUIRE((wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform>::storage().size() == 0));
  214. }
  215. template <typename platform = wil::apartment_variable_platform>
  216. void TestWinningApartmentAlreadyRundownRace()
  217. {
  218. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  219. wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform> av;
  220. std::ignore = av.get_or_create(fn);
  221. const auto& storage = av.storage(); // for viewing the storage in the debugger
  222. wil::unique_event otherAptVarCreated{ wil::EventOptions::None };
  223. wil::unique_event startApartmentRundown{ wil::EventOptions::None };
  224. wil::unique_event comRundownComplete{ wil::EventOptions::None };
  225. auto apt_thread = std::thread([&]() // join below makes this ok
  226. {
  227. SetThreadDescription(GetCurrentThread(), L"STA");
  228. auto coUninit = platform::CoInitializeEx(COINIT_APARTMENTTHREADED);
  229. std::ignore = av.get_or_create(fn);
  230. otherAptVarCreated.SetEvent();
  231. co_wait(startApartmentRundown);
  232. });
  233. otherAptVarCreated.wait();
  234. // we now have av in this apartment and in the STA
  235. REQUIRE(storage.size() == 2);
  236. // wait for async clean to complete
  237. av.clear_all_apartments_async().get();
  238. startApartmentRundown.SetEvent();
  239. REQUIRE(av.storage().size() == 0);
  240. apt_thread.join();
  241. }
  242. template <typename platform = wil::apartment_variable_platform>
  243. void TestLosingApartmentAlreadyRundownRace()
  244. {
  245. auto coUninit = platform::CoInitializeEx(COINIT_MULTITHREADED);
  246. wil::apartment_variable<int, wil::apartment_variable_leak_action::fail_fast, platform> av;
  247. std::ignore = av.get_or_create(fn);
  248. const auto& storage = av.storage(); // for viewing the storage in the debugger
  249. wil::unique_event otherAptVarCreated{ wil::EventOptions::None };
  250. wil::unique_event startApartmentRundown{ wil::EventOptions::None };
  251. wil::unique_event comRundownComplete{ wil::EventOptions::None };
  252. auto apt_thread = std::thread([&]() // join below makes this ok
  253. {
  254. SetThreadDescription(GetCurrentThread(), L"STA");
  255. auto coUninit = platform::CoInitializeEx(COINIT_APARTMENTTHREADED);
  256. std::ignore = av.get_or_create(fn);
  257. otherAptVarCreated.SetEvent();
  258. co_wait(startApartmentRundown);
  259. coUninit.reset();
  260. comRundownComplete.SetEvent();
  261. });
  262. otherAptVarCreated.wait();
  263. // we now have av in this apartment and in the STA
  264. REQUIRE(storage.size() == 2);
  265. auto clearAllOperation = av.clear_all_apartments_async();
  266. startApartmentRundown.SetEvent();
  267. comRundownComplete.wait();
  268. clearAllOperation.get(); // wait for the async rundowns to complete
  269. REQUIRE(av.storage().size() == 0);
  270. apt_thread.join();
  271. }
  272. TEST_CASE("ComApartmentVariable::ShutdownRegistration", "[LocalOnly][com][unique_apartment_shutdown_registration]")
  273. {
  274. {
  275. wil::unique_apartment_shutdown_registration r;
  276. }
  277. {
  278. auto coUninit = wil::CoInitializeEx(COINIT_MULTITHREADED);
  279. struct ApartmentObserver : public winrt::implements<ApartmentObserver, IApartmentShutdown>
  280. {
  281. void STDMETHODCALLTYPE OnUninitialize(unsigned long long apartmentId) noexcept override
  282. {
  283. LogOutput(L"OnUninitialize %ull\n", apartmentId);
  284. }
  285. };
  286. wil::unique_apartment_shutdown_registration apt_shutdown_registration;
  287. unsigned long long id{};
  288. REQUIRE_SUCCEEDED(::RoRegisterForApartmentShutdown(winrt::make<ApartmentObserver>().get(), &id, apt_shutdown_registration.put()));
  289. LogOutput(L"RoRegisterForApartmentShutdown %p\r\n", apt_shutdown_registration.get());
  290. // don't unregister and let the pending COM apartment rundown invoke the callback.
  291. apt_shutdown_registration.release();
  292. }
  293. }
  294. TEST_CASE("ComApartmentVariable::CallAllMethods", "[com][apartment_variable]")
  295. {
  296. RunApartmentVariableTest(TestApartmentVariableAllMethods<mock_platform>);
  297. }
  298. TEST_CASE("ComApartmentVariable::GetOrCreateForms", "[com][apartment_variable]")
  299. {
  300. RunApartmentVariableTest(TestApartmentVariableGetOrCreateForms<mock_platform>);
  301. }
  302. TEST_CASE("ComApartmentVariable::VariableLifetimes", "[com][apartment_variable]")
  303. {
  304. RunApartmentVariableTest(TestApartmentVariableLifetimes<mock_platform>);
  305. }
  306. TEST_CASE("ComApartmentVariable::WinningApartmentAlreadyRundownRace", "[com][apartment_variable]")
  307. {
  308. RunApartmentVariableTest(TestWinningApartmentAlreadyRundownRace<mock_platform>);
  309. }
  310. TEST_CASE("ComApartmentVariable::LosingApartmentAlreadyRundownRace", "[com][apartment_variable]")
  311. {
  312. RunApartmentVariableTest(TestLosingApartmentAlreadyRundownRace<mock_platform>);
  313. }
  314. TEST_CASE("ComApartmentVariable::MultipleApartments", "[com][apartment_variable]")
  315. {
  316. RunApartmentVariableTest(TestMultipleApartments<mock_platform>);
  317. }
  318. TEST_CASE("ComApartmentVariable::UseRealPlatformRunAllTests", "[com][apartment_variable]")
  319. {
  320. if (!wil::are_apartment_variables_supported())
  321. {
  322. return;
  323. }
  324. RunApartmentVariableTest(TestApartmentVariableAllMethods);
  325. RunApartmentVariableTest(TestApartmentVariableGetOrCreateForms);
  326. RunApartmentVariableTest(TestApartmentVariableLifetimes);
  327. RunApartmentVariableTest(TestWinningApartmentAlreadyRundownRace);
  328. RunApartmentVariableTest(TestLosingApartmentAlreadyRundownRace);
  329. RunApartmentVariableTest(TestMultipleApartments);
  330. }
  331. #endif