so.cc 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #include <dlfcn.h>
  2. #include <assert.h>
  3. #include <unistd.h>
  4. #include <vtv_fail.h>
  5. extern "C" int printf(const char *, ...);
  6. extern "C" int sprintf(char *, const char*, ...);
  7. static int counter = 0;
  8. extern int failures;
  9. template <int i> struct base
  10. {
  11. virtual char * whoami() {
  12. static char sl[100];
  13. sprintf(sl, "I am base %d", i);
  14. return sl;
  15. }
  16. virtual void inc() { counter += i; }
  17. };
  18. template <int i> struct derived: base<i>
  19. {
  20. virtual char * whoami() {
  21. static char sl[100];
  22. sprintf(sl, "I am derived %d", i);
  23. return sl;
  24. }
  25. virtual void inc() { counter += (10*i); }
  26. };
  27. // We don't use this class. It is just here so that the
  28. // compiler does not devirtualize calls to derived::inc()
  29. template <int i> struct derived2: derived<i>
  30. {
  31. virtual void inc() { counter += (20*i); }
  32. };
  33. static base<TPID> * bp = new base<TPID>();
  34. static derived<TPID> * dp = new derived<TPID>();
  35. static base<TPID> * dbp = new derived<TPID>();
  36. // Given 2 pointers to C++ objects (non PODs), exchange the pointers to vtable
  37. static void exchange_vtptr(void * object1_ptr, void * object2_ptr)
  38. {
  39. void ** object1_vtptr_ptr = (void **)object1_ptr;
  40. void ** object2_vtptr_ptr = (void **)object2_ptr;
  41. void * object1_vtptr = *object1_vtptr_ptr;
  42. void * object2_vtptr = *object2_vtptr_ptr;
  43. *object1_vtptr_ptr = object2_vtptr;
  44. *object2_vtptr_ptr = object1_vtptr;
  45. }
  46. #define BUILD_NAME(NAME,ID) NAME##ID
  47. #define EXPAND(NAME,X) BUILD_NAME(NAME,X)
  48. extern "C" void EXPAND(so_entry_,TPID)(void)
  49. {
  50. int prev_counter;
  51. int prev_failures;
  52. counter = 0;
  53. bp->inc();
  54. dp->inc();
  55. dbp->inc();
  56. assert(counter == (TPID + 10*TPID + 10*TPID));
  57. prev_counter = counter;
  58. exchange_vtptr(bp, dp);
  59. bp->inc(); // This one should succeed but it is calling the wrong member
  60. if (counter != (prev_counter + 10*TPID))
  61. {
  62. printf("TPID=%d whoami=%s wrong counter value prev_counter=%d counter=%d\n", TPID, bp->whoami(), prev_counter, counter);
  63. sleep(2);
  64. }
  65. assert(counter == (prev_counter + 10*TPID));
  66. // printf("Pass first attack!\n");
  67. // This one should fail verification!. So it should jump to __vtv_verify_fail above.
  68. prev_failures = failures;
  69. dp->inc();
  70. // this code may be executed by multiple threads at the same time. So, just verify the number of failures has
  71. // increased as opposed to check for increase by 1.
  72. assert(failures > prev_failures);
  73. assert(counter == (prev_counter + 10*TPID + TPID));
  74. // printf("TPDI=%d counter %d\n", TPID, counter);
  75. // printf("Pass second attack!\n");
  76. // restore the vtable pointers to the original state.
  77. // This is very important. For some reason the dlclose is not "really" closing the library so when we reopen it we are
  78. // getting the old memory state.
  79. exchange_vtptr(bp, dp);
  80. }