reed-solomon.hpp 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #pragma once
  2. namespace nall {
  3. //RS(n,k) = ReedSolomon<Length, Inputs>
  4. template<uint Length, uint Inputs>
  5. struct ReedSolomon {
  6. enum : uint { Parity = Length - Inputs };
  7. static_assert(Length <= 255 && Length > 0);
  8. static_assert(Parity <= 32 && Parity > 0);
  9. using Field = GaloisField<uint8_t, 255, 0x11d>;
  10. template<uint Rows, uint Cols = 1> using Polynomial = Matrix<Field, Rows, Cols>;
  11. template<uint Size>
  12. static auto shift(Polynomial<Size> polynomial) -> Polynomial<Size> {
  13. for(int n = Size - 1; n > 0; n--) polynomial[n] = polynomial[n - 1];
  14. polynomial[0] = 0;
  15. return polynomial;
  16. }
  17. template<uint Size>
  18. static auto degree(const Polynomial<Size>& polynomial) -> uint {
  19. for(int n = Size; n > 0; n--) {
  20. if(polynomial[n - 1] != 0) return n - 1;
  21. }
  22. return 0;
  23. }
  24. template<uint Size>
  25. static auto evaluate(const Polynomial<Size>& polynomial, Field field) -> Field {
  26. Field sum = 0;
  27. for(uint n : range(Size)) sum += polynomial[n] * field.pow(n);
  28. return sum;
  29. }
  30. Polynomial<Length> message;
  31. Polynomial<Parity> syndromes;
  32. Polynomial<Parity + 1> locators;
  33. ReedSolomon() = default;
  34. ReedSolomon(const ReedSolomon&) = default;
  35. ReedSolomon(const initializer_list<uint8_t>& source) {
  36. uint index = 0;
  37. for(auto& value : source) {
  38. if(index >= Length) break;
  39. message[index++] = value;
  40. }
  41. }
  42. auto operator[](uint index) -> Field& { return message[index]; }
  43. auto operator[](uint index) const -> Field { return message[index]; }
  44. auto calculateSyndromes() -> void {
  45. static const Polynomial<Parity> bases = [] {
  46. Polynomial<Parity> bases;
  47. for(uint n : range(Parity)) {
  48. bases[n] = Field::exp(n);
  49. }
  50. return bases;
  51. }();
  52. syndromes = {};
  53. for(uint m : range(Length)) {
  54. for(uint p : range(Parity)) {
  55. syndromes[p] *= bases[p];
  56. syndromes[p] += message[m];
  57. }
  58. }
  59. }
  60. auto generateParity() -> void {
  61. static const Polynomial<Parity, Parity> matrix = [] {
  62. Polynomial<Parity, Parity> matrix{};
  63. for(uint row : range(Parity)) {
  64. for(uint col : range(Parity)) {
  65. matrix(row, col) = Field::exp(row * col);
  66. }
  67. }
  68. if(auto result = matrix.invert()) return *result;
  69. throw; //should never occur
  70. }();
  71. for(uint p : range(Parity)) message[Inputs + p] = 0;
  72. calculateSyndromes();
  73. auto parity = matrix * syndromes;
  74. for(uint p : range(Parity)) message[Inputs + p] = parity[Parity - (p + 1)];
  75. }
  76. auto syndromesAreZero() -> bool {
  77. for(uint p : range(Parity)) {
  78. if(syndromes[p]) return false;
  79. }
  80. return true;
  81. }
  82. //algorithm: Berlekamp-Massey
  83. auto calculateLocators() -> void {
  84. Polynomial<Parity + 1> history{1};
  85. locators = history;
  86. uint errors = 0;
  87. for(uint n : range(Parity)) {
  88. Field discrepancy = 0;
  89. for(uint l : range(errors + 1)) {
  90. discrepancy += locators[l] * syndromes[n - l];
  91. }
  92. history = shift(history);
  93. if(discrepancy) {
  94. auto located = locators - history * discrepancy;
  95. if(errors * 2 <= n) {
  96. errors = (n + 1) - errors;
  97. history = locators * discrepancy.inv();
  98. }
  99. locators = located;
  100. }
  101. }
  102. }
  103. //algorithm: brute force
  104. //todo: implement Chien search here
  105. auto calculateErrors() -> vector<uint8_t> {
  106. calculateSyndromes();
  107. if(syndromesAreZero()) return {}; //no errors detected
  108. calculateLocators();
  109. vector<uint8_t> errors;
  110. for(uint n : range(Length)) {
  111. if(evaluate(locators, Field{2}.pow(255 - n))) continue;
  112. errors.append(Length - (n + 1));
  113. }
  114. return errors;
  115. }
  116. template<uint Size>
  117. static auto calculateErasures(array_view<uint8_t> errors) -> maybe<Polynomial<Size, Size>> {
  118. Polynomial<Size, Size> matrix{};
  119. for(uint row : range(Size)) {
  120. for(uint col : range(Size)) {
  121. uint index = Length - (errors[col] + 1);
  122. matrix(row, col) = Field::exp(row * index);
  123. }
  124. }
  125. return matrix.invert();
  126. }
  127. template<uint Size>
  128. auto correctErasures(array_view<uint8_t> errors) -> int {
  129. calculateSyndromes();
  130. if(syndromesAreZero()) return 0; //no errors detected
  131. if(auto matrix = calculateErasures<Size>(errors)) {
  132. Polynomial<Size> factors;
  133. for(uint n : range(Size)) factors[n] = syndromes[n];
  134. auto errata = matrix() * factors;
  135. for(uint m : range(Size)) {
  136. message[errors[m]] += errata[m];
  137. }
  138. calculateSyndromes();
  139. if(syndromesAreZero()) return Size; //corrected Size errors
  140. return -Size; //failed to correct Size errors
  141. }
  142. return -Size; //should never occur, but might ...
  143. }
  144. //note: the erasure matrix is generated as a Polynomial<NxN>, where N is the number of errors to correct.
  145. //because this is a template parameter, and the actual number of errors may very, this function is needed.
  146. //the alternative would be to convert Matrix<Rows, Cols> to a dynamically sized Matrix(Rows, Cols) type,
  147. //but this would require heap memory allocations and would be a massive performance penalty.
  148. auto correctErrata(array_view<uint8_t> errors) -> int {
  149. if(errors.size() >= Parity) return -errors.size(); //too many errors to be correctable
  150. switch(errors.size()) {
  151. case 0: return 0;
  152. case 1: return correctErasures< 1>(errors);
  153. case 2: return correctErasures< 2>(errors);
  154. case 3: return correctErasures< 3>(errors);
  155. case 4: return correctErasures< 4>(errors);
  156. case 5: return correctErasures< 5>(errors);
  157. case 6: return correctErasures< 6>(errors);
  158. case 7: return correctErasures< 7>(errors);
  159. case 8: return correctErasures< 8>(errors);
  160. case 9: return correctErasures< 9>(errors);
  161. case 10: return correctErasures<10>(errors);
  162. case 11: return correctErasures<11>(errors);
  163. case 12: return correctErasures<12>(errors);
  164. case 13: return correctErasures<13>(errors);
  165. case 14: return correctErasures<14>(errors);
  166. case 15: return correctErasures<15>(errors);
  167. case 16: return correctErasures<16>(errors);
  168. case 17: return correctErasures<17>(errors);
  169. case 18: return correctErasures<18>(errors);
  170. case 19: return correctErasures<19>(errors);
  171. case 20: return correctErasures<20>(errors);
  172. case 21: return correctErasures<21>(errors);
  173. case 22: return correctErasures<22>(errors);
  174. case 23: return correctErasures<23>(errors);
  175. case 24: return correctErasures<24>(errors);
  176. case 25: return correctErasures<25>(errors);
  177. case 26: return correctErasures<26>(errors);
  178. case 27: return correctErasures<27>(errors);
  179. case 28: return correctErasures<28>(errors);
  180. case 29: return correctErasures<29>(errors);
  181. case 30: return correctErasures<30>(errors);
  182. case 31: return correctErasures<31>(errors);
  183. case 32: return correctErasures<32>(errors);
  184. }
  185. return -errors.size(); //it's possible to correct more errors if the above switch were extended ...
  186. }
  187. //convenience function for when erasures aren't needed
  188. auto correctErrors() -> int {
  189. auto errors = calculateErrors();
  190. return correctErrata(errors);
  191. }
  192. };
  193. }