modulo25519-optimized.hpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #pragma once
  2. #include <nall/arithmetic/barrett.hpp>
  3. namespace nall::EllipticCurve {
  4. static const uint256_t P = (1_u256 << 255) - 19;
  5. #define Mask ((1ull << 51) - 1)
  6. struct Modulo25519 {
  7. Modulo25519() = default;
  8. Modulo25519(const Modulo25519&) = default;
  9. Modulo25519(uint64_t a, uint64_t b = 0, uint64_t c = 0, uint64_t d = 0, uint64_t e = 0) : l{a, b, c, d, e} {}
  10. Modulo25519(uint256_t n);
  11. explicit operator bool() const { return (bool)operator()(); }
  12. auto operator[](uint index) -> uint64_t& { return l[index]; }
  13. auto operator[](uint index) const -> uint64_t { return l[index]; }
  14. auto operator()() const -> uint256_t;
  15. private:
  16. uint64_t l[5]; //51-bits per limb; 255-bits total
  17. };
  18. inline Modulo25519::Modulo25519(uint256_t n) {
  19. l[0] = n >> 0 & Mask;
  20. l[1] = n >> 51 & Mask;
  21. l[2] = n >> 102 & Mask;
  22. l[3] = n >> 153 & Mask;
  23. l[4] = n >> 204 & Mask;
  24. }
  25. inline auto Modulo25519::operator()() const -> uint256_t {
  26. Modulo25519 o = *this;
  27. o[1] += (o[0] >> 51); o[0] &= Mask;
  28. o[2] += (o[1] >> 51); o[1] &= Mask;
  29. o[3] += (o[2] >> 51); o[2] &= Mask;
  30. o[4] += (o[3] >> 51); o[3] &= Mask;
  31. o[0] += 19 * (o[4] >> 51); o[4] &= Mask;
  32. o[1] += (o[0] >> 51); o[0] &= Mask;
  33. o[2] += (o[1] >> 51); o[1] &= Mask;
  34. o[3] += (o[2] >> 51); o[2] &= Mask;
  35. o[4] += (o[3] >> 51); o[3] &= Mask;
  36. o[0] += 19 * (o[4] >> 51); o[4] &= Mask;
  37. o[0] += 19;
  38. o[1] += (o[0] >> 51); o[0] &= Mask;
  39. o[2] += (o[1] >> 51); o[1] &= Mask;
  40. o[3] += (o[2] >> 51); o[2] &= Mask;
  41. o[4] += (o[3] >> 51); o[3] &= Mask;
  42. o[0] += 19 * (o[4] >> 51); o[4] &= Mask;
  43. o[0] += Mask - 18;
  44. o[1] += Mask;
  45. o[2] += Mask;
  46. o[3] += Mask;
  47. o[4] += Mask;
  48. o[1] += o[0] >> 51; o[0] &= Mask;
  49. o[2] += o[1] >> 51; o[1] &= Mask;
  50. o[3] += o[2] >> 51; o[2] &= Mask;
  51. o[4] += o[3] >> 51; o[3] &= Mask;
  52. o[4] &= Mask;
  53. return (uint256_t)o[0] << 0 | (uint256_t)o[1] << 51 | (uint256_t)o[2] << 102 | (uint256_t)o[3] << 153 | (uint256_t)o[4] << 204;
  54. }
  55. inline auto cmove(bool move, Modulo25519& l, const Modulo25519& r) -> void {
  56. uint64_t mask = -move;
  57. l[0] ^= mask & (l[0] ^ r[0]);
  58. l[1] ^= mask & (l[1] ^ r[1]);
  59. l[2] ^= mask & (l[2] ^ r[2]);
  60. l[3] ^= mask & (l[3] ^ r[3]);
  61. l[4] ^= mask & (l[4] ^ r[4]);
  62. }
  63. inline auto cswap(bool swap, Modulo25519& l, Modulo25519& r) -> void {
  64. uint64_t mask = -swap, x;
  65. x = mask & (l[0] ^ r[0]); l[0] ^= x; r[0] ^= x;
  66. x = mask & (l[1] ^ r[1]); l[1] ^= x; r[1] ^= x;
  67. x = mask & (l[2] ^ r[2]); l[2] ^= x; r[2] ^= x;
  68. x = mask & (l[3] ^ r[3]); l[3] ^= x; r[3] ^= x;
  69. x = mask & (l[4] ^ r[4]); l[4] ^= x; r[4] ^= x;
  70. }
  71. inline auto operator-(const Modulo25519& l) -> Modulo25519 { //P - l
  72. Modulo25519 o;
  73. uint64_t c;
  74. o[0] = 0xfffffffffffda - l[0]; c = o[0] >> 51; o[0] &= Mask;
  75. o[1] = 0xffffffffffffe - l[1] + c; c = o[1] >> 51; o[1] &= Mask;
  76. o[2] = 0xffffffffffffe - l[2] + c; c = o[2] >> 51; o[2] &= Mask;
  77. o[3] = 0xffffffffffffe - l[3] + c; c = o[3] >> 51; o[3] &= Mask;
  78. o[4] = 0xffffffffffffe - l[4] + c; c = o[4] >> 51; o[4] &= Mask;
  79. o[0] += c * 19;
  80. return o;
  81. }
  82. inline auto operator+(const Modulo25519& l, const Modulo25519& r) -> Modulo25519 {
  83. Modulo25519 o;
  84. uint64_t c;
  85. o[0] = l[0] + r[0]; c = o[0] >> 51; o[0] &= Mask;
  86. o[1] = l[1] + r[1] + c; c = o[1] >> 51; o[1] &= Mask;
  87. o[2] = l[2] + r[2] + c; c = o[2] >> 51; o[2] &= Mask;
  88. o[3] = l[3] + r[3] + c; c = o[3] >> 51; o[3] &= Mask;
  89. o[4] = l[4] + r[4] + c; c = o[4] >> 51; o[4] &= Mask;
  90. o[0] += c * 19;
  91. return o;
  92. }
  93. inline auto operator-(const Modulo25519& l, const Modulo25519& r) -> Modulo25519 {
  94. Modulo25519 o;
  95. uint64_t c;
  96. o[0] = l[0] + 0x1fffffffffffb4 - r[0]; c = o[0] >> 51; o[0] &= Mask;
  97. o[1] = l[1] + 0x1ffffffffffffc - r[1] + c; c = o[1] >> 51; o[1] &= Mask;
  98. o[2] = l[2] + 0x1ffffffffffffc - r[2] + c; c = o[2] >> 51; o[2] &= Mask;
  99. o[3] = l[3] + 0x1ffffffffffffc - r[3] + c; c = o[3] >> 51; o[3] &= Mask;
  100. o[4] = l[4] + 0x1ffffffffffffc - r[4] + c; c = o[4] >> 51; o[4] &= Mask;
  101. o[0] += c * 19;
  102. return o;
  103. }
  104. inline auto operator*(const Modulo25519& l, uint64_t scalar) -> Modulo25519 {
  105. Modulo25519 o;
  106. uint128_t a;
  107. a = (uint128_t)l[0] * scalar; o[0] = a & Mask;
  108. a = (uint128_t)l[1] * scalar + (a >> 51 & Mask); o[1] = a & Mask;
  109. a = (uint128_t)l[2] * scalar + (a >> 51 & Mask); o[2] = a & Mask;
  110. a = (uint128_t)l[3] * scalar + (a >> 51 & Mask); o[3] = a & Mask;
  111. a = (uint128_t)l[4] * scalar + (a >> 51 & Mask); o[4] = a & Mask;
  112. o[0] += (a >> 51) * 19;
  113. return o;
  114. }
  115. inline auto operator*(const Modulo25519& l, Modulo25519 r) -> Modulo25519 {
  116. uint128_t t[] = {
  117. (uint128_t)r[0] * l[0],
  118. (uint128_t)r[0] * l[1] + (uint128_t)r[1] * l[0],
  119. (uint128_t)r[0] * l[2] + (uint128_t)r[1] * l[1] + (uint128_t)r[2] * l[0],
  120. (uint128_t)r[0] * l[3] + (uint128_t)r[1] * l[2] + (uint128_t)r[2] * l[1] + (uint128_t)r[3] * l[0],
  121. (uint128_t)r[0] * l[4] + (uint128_t)r[1] * l[3] + (uint128_t)r[2] * l[2] + (uint128_t)r[3] * l[1] + (uint128_t)r[4] * l[0]
  122. };
  123. r[1] *= 19, r[2] *= 19, r[3] *= 19, r[4] *= 19;
  124. t[0] += (uint128_t)r[4] * l[1] + (uint128_t)r[3] * l[2] + (uint128_t)r[2] * l[3] + (uint128_t)r[1] * l[4];
  125. t[1] += (uint128_t)r[4] * l[2] + (uint128_t)r[3] * l[3] + (uint128_t)r[2] * l[4];
  126. t[2] += (uint128_t)r[4] * l[3] + (uint128_t)r[3] * l[4];
  127. t[3] += (uint128_t)r[4] * l[4];
  128. uint64_t c; r[0] = t[0] & Mask; c = (uint64_t)(t[0] >> 51);
  129. t[1] += c; r[1] = t[1] & Mask; c = (uint64_t)(t[1] >> 51);
  130. t[2] += c; r[2] = t[2] & Mask; c = (uint64_t)(t[2] >> 51);
  131. t[3] += c; r[3] = t[3] & Mask; c = (uint64_t)(t[3] >> 51);
  132. t[4] += c; r[4] = t[4] & Mask; c = (uint64_t)(t[4] >> 51);
  133. r[0] += c * 19; c = r[0] >> 51; r[0] &= Mask;
  134. r[1] += c; c = r[1] >> 51; r[1] &= Mask;
  135. r[2] += c;
  136. return r;
  137. }
  138. inline auto operator&(const Modulo25519& lhs, uint256_t rhs) -> uint256_t {
  139. return lhs() & rhs;
  140. }
  141. inline auto square(const Modulo25519& lhs) -> Modulo25519 {
  142. Modulo25519 r{lhs};
  143. Modulo25519 d{r[0] * 2, r[1] * 2, r[2] * 2 * 19, r[4] * 19, r[4] * 19 * 2};
  144. uint128_t t[5];
  145. t[0] = (uint128_t)r[0] * r[0] + (uint128_t)d[4] * r[1] + (uint128_t)d[2] * r[3];
  146. t[1] = (uint128_t)d[0] * r[1] + (uint128_t)d[4] * r[2] + (uint128_t)r[3] * r[3] * 19;
  147. t[2] = (uint128_t)d[0] * r[2] + (uint128_t)r[1] * r[1] + (uint128_t)d[4] * r[3];
  148. t[3] = (uint128_t)d[0] * r[3] + (uint128_t)d[1] * r[2] + (uint128_t)r[4] * d[3];
  149. t[4] = (uint128_t)d[0] * r[4] + (uint128_t)d[1] * r[3] + (uint128_t)r[2] * r[2];
  150. uint64_t c; r[0] = t[0] & Mask; c = (uint64_t)(t[0] >> 51);
  151. t[1] += c; r[1] = t[1] & Mask; c = (uint64_t)(t[1] >> 51);
  152. t[2] += c; r[2] = t[2] & Mask; c = (uint64_t)(t[2] >> 51);
  153. t[3] += c; r[3] = t[3] & Mask; c = (uint64_t)(t[3] >> 51);
  154. t[4] += c; r[4] = t[4] & Mask; c = (uint64_t)(t[4] >> 51);
  155. r[0] += c * 19; c = r[0] >> 51; r[0] &= Mask;
  156. r[1] += c; c = r[1] >> 51; r[1] &= Mask;
  157. r[2] += c;
  158. return r;
  159. }
  160. inline auto exponentiate(const Modulo25519& lhs, uint256_t exponent) -> Modulo25519 {
  161. Modulo25519 x = 1, y;
  162. for(uint bit : reverse(range(256))) {
  163. x = square(x);
  164. y = x * lhs;
  165. cmove(exponent >> bit & 1, x, y);
  166. }
  167. return x;
  168. }
  169. inline auto reciprocal(const Modulo25519& lhs) -> Modulo25519 {
  170. return exponentiate(lhs, P - 2);
  171. }
  172. inline auto squareRoot(const Modulo25519& lhs) -> Modulo25519 {
  173. static const Modulo25519 I = exponentiate(Modulo25519(2), P - 1 >> 2); //I == sqrt(-1)
  174. Modulo25519 x = exponentiate(lhs, P + 3 >> 3);
  175. Modulo25519 y = x * I;
  176. cmove(bool(square(x) - lhs), x, y);
  177. y = -x;
  178. cmove(x & 1, x, y);
  179. return x;
  180. }
  181. #undef Mask
  182. }