main.hpp 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. // https://cirosantilli.com/linux-kernel-module-cheat#atomic-cpp
  2. #if __cplusplus >= 201103L
  3. #include <atomic>
  4. #include <cassert>
  5. #include <iostream>
  6. #include <mutex>
  7. #include <thread>
  8. #include <vector>
  // Number of loop iterations each worker thread performs. main() writes it
// before any thread is started; the worker threads only read it.
size_t niters = 0;

// The counter every worker thread increments. It is a std::atomic only in the
// LKMC_USERLAND_ATOMIC_STD_ATOMIC variant; all other variants use a plain
// integer and rely on the selected increment strategy (or the mutex).
#if LKMC_USERLAND_ATOMIC_STD_ATOMIC
std::atomic_ulong global(0);
#else
uint64_t global = 0;
#endif

#if LKMC_USERLAND_ATOMIC_MUTEX
// Serializes each increment of `global` in the mutex variant.
std::mutex mutex;
#endif

// Worker thread body: increments `global` `niters` times using the strategy
// selected at compile time by the LKMC_USERLAND_ATOMIC_* macros. Some
// strategies are atomic; the non-atomic ones can lose increments when several
// threads run this function concurrently.
void threadMain() {
    for (size_t i = 0; i < niters; ++i) {
#if LKMC_USERLAND_ATOMIC_MUTEX
        // Held for the rest of this iteration, i.e. exactly around the
        // increment below, and released even if the body is left early.
        std::lock_guard<std::mutex> guard(mutex);
#endif
#if LKMC_USERLAND_ATOMIC_X86_64_INC
        // Plain `inc` without `lock`: not atomic across threads.
        __asm__ __volatile__ (
            "incq %0;"
            : "+g" (global),
              "+g" (i) // to prevent loop unrolling, and make results more comparable across methods,
                       // see also: https://cirosantilli.com/linux-kernel-module-cheat#infinite-busy-loop
            :
            :
        );
#elif LKMC_USERLAND_ATOMIC_X86_64_LOCK_INC
        // https://cirosantilli.com/linux-kernel-module-cheat#x86-lock-prefix
        __asm__ __volatile__ (
            "lock incq %0;"
            : "+m" (global),
              "+g" (i) // to prevent loop unrolling
            :
            :
        );
#elif LKMC_USERLAND_ATOMIC_AARCH64_ADD
        // "+r": the compiler loads `global` into a register, the add runs on
        // the register, and the result is stored back, so this is not atomic.
        __asm__ __volatile__ (
            "add %0, %0, 1;"
            : "+r" (global),
              "+g" (i) // to prevent loop unrolling
            :
            :
        );
#elif LKMC_USERLAND_ATOMIC_AARCH64_LDADD
        // https://cirosantilli.com/linux-kernel-module-cheat#arm-lse
        // ldadd reads `global`, adds %[inc] and writes it back, so the memory
        // operand is read-write ("+m"), not write-only ("=m"). The increment is
        // a 64-bit value so the whole X register named by %[inc] is defined.
        __asm__ __volatile__ (
            "ldadd %[inc], xzr, [%[addr]];"
            : "+m" (global),
              "+g" (i) // to prevent loop unrolling
            : [inc] "r" (static_cast<uint64_t>(1)),
              [addr] "r" (&global)
            :
        );
#else
        // Plain C++ increment: atomic only when `global` is a std::atomic
        // (LKMC_USERLAND_ATOMIC_STD_ATOMIC) or the mutex variant is enabled.
        __asm__ __volatile__ (
            ""
            : "+g" (i) // to prevent the loop from being optimized to a single add
                       // see also: https://stackoverflow.com/questions/37786547/enforcing-statement-order-in-c/56865717#56865717
            : "g" (global)
            :
        );
        global++;
#endif
    }
}
  74. #endif
  75. int main(int argc, char **argv) {
  76. #if __cplusplus >= 201103L
  77. size_t nthreads;
  78. if (argc > 1) {
  79. nthreads = std::stoull(argv[1], NULL, 0);
  80. } else {
  81. nthreads = 2;
  82. }
  83. if (argc > 2) {
  84. niters = std::stoull(argv[2], NULL, 0);
  85. } else {
  86. niters = 10;
  87. }
  88. std::vector<std::thread> threads(nthreads);
  89. for (size_t i = 0; i < nthreads; ++i)
  90. threads[i] = std::thread(threadMain);
  91. for (size_t i = 0; i < nthreads; ++i)
  92. threads[i].join();
  93. uint64_t expect = nthreads * niters;
  94. #if LKMC_USERLAND_ATOMIC_FAIL || \
  95. LKMC_USERLAND_ATOMIC_X86_64_INC || \
  96. LKMC_USERLAND_ATOMIC_AARCH64_INC
  97. // These fail, so we just print the outcomes.
  98. std::cout << "expect " << expect << std::endl;
  99. std::cout << "global " << global << std::endl;
  100. #else
  101. assert(global == expect);
  102. #endif
  103. #endif
  104. }