//matrix.hpp
  1. #pragma once
  2. namespace nall {
  3. template<typename T, uint Rows, uint Cols>
  4. struct Matrix {
  5. static_assert(Rows > 0 && Cols > 0);
  6. Matrix() = default;
  7. Matrix(const Matrix&) = default;
  8. Matrix(const initializer_list<T>& source) {
  9. uint index = 0;
  10. for(auto& value : source) {
  11. if(index >= Rows * Cols) break;
  12. values[index / Cols][index % Cols] = value;
  13. }
  14. }
  15. operator array_span<T>() { return {values, Rows * Cols}; }
  16. operator array_view<T>() const { return {values, Rows * Cols}; }
  17. //1D matrices (for polynomials, etc)
  18. auto operator[](uint row) -> T& { return values[row][0]; }
  19. auto operator[](uint row) const -> T { return values[row][0]; }
  20. //2D matrices
  21. auto operator()(uint row, uint col) -> T& { return values[row][col]; }
  22. auto operator()(uint row, uint col) const -> T { return values[row][col]; }
  23. //operators
  24. auto operator+() const -> Matrix {
  25. Matrix result;
  26. for(uint row : range(Rows)) {
  27. for(uint col : range(Cols)) {
  28. result(row, col) = +target(row, col);
  29. }
  30. }
  31. return result;
  32. }
  33. auto operator-() const -> Matrix {
  34. Matrix result;
  35. for(uint row : range(Rows)) {
  36. for(uint col : range(Cols)) {
  37. result(row, col) = -target(row, col);
  38. }
  39. }
  40. return result;
  41. }
  42. auto operator+(const Matrix& source) const -> Matrix {
  43. Matrix result;
  44. for(uint row : range(Rows)) {
  45. for(uint col : range(Cols)) {
  46. result(row, col) = target(row, col) + source(row, col);
  47. }
  48. }
  49. return result;
  50. }
  51. auto operator-(const Matrix& source) const -> Matrix {
  52. Matrix result;
  53. for(uint row : range(Rows)) {
  54. for(uint col : range(Cols)) {
  55. result(row, col) = target(row, col) - source(row, col);
  56. }
  57. }
  58. return result;
  59. }
  60. auto operator*(T source) const -> Matrix {
  61. Matrix result;
  62. for(uint row : range(Rows)) {
  63. for(uint col : range(Cols)) {
  64. result(row, col) = target(row, col) * source;
  65. }
  66. }
  67. return result;
  68. }
  69. auto operator/(T source) const -> Matrix {
  70. Matrix result;
  71. for(uint row : range(Rows)) {
  72. for(uint col : range(Cols)) {
  73. result(row, col) = target(row, col) / source;
  74. }
  75. }
  76. return result;
  77. }
  78. //warning: matrix multiplication is not commutative!
  79. template<uint SourceRows, uint SourceCols>
  80. auto operator*(const Matrix<T, SourceRows, SourceCols>& source) const -> Matrix<T, Rows, SourceCols> {
  81. static_assert(Cols == SourceRows);
  82. Matrix<T, Rows, SourceCols> result;
  83. for(uint y : range(Rows)) {
  84. for(uint x : range(SourceCols)) {
  85. T sum{};
  86. for(uint z : range(Cols)) {
  87. sum += target(y, z) * source(z, x);
  88. }
  89. result(y, x) = sum;
  90. }
  91. }
  92. return result;
  93. }
  94. template<uint SourceRows, uint SourceCols>
  95. auto operator/(const Matrix<T, SourceRows, SourceCols>& source) const -> maybe<Matrix<T, Rows, SourceCols>> {
  96. static_assert(Cols == SourceRows && SourceRows == SourceCols);
  97. if(auto inverted = source.invert()) return operator*(inverted());
  98. return {};
  99. }
  100. auto& operator+=(const Matrix& source) { return *this = operator+(source); }
  101. auto& operator-=(const Matrix& source) { return *this = operator-(source); }
  102. auto& operator*=(T source) { return *this = operator*(source); }
  103. auto& operator/=(T source) { return *this = operator/(source); }
  104. template<uint SourceRows, uint SourceCols>
  105. auto& operator*=(const Matrix<T, SourceRows, SourceCols>& source) { return *this = operator*(source); }
  106. //matrix division is not always possible (when matrix cannot be inverted), so operator/= is not provided
  107. //algorithm: Gauss-Jordan
  108. auto invert() const -> maybe<Matrix> {
  109. static_assert(Rows == Cols);
  110. Matrix source = *this;
  111. Matrix result = identity();
  112. const auto add = [&](uint targetRow, uint sourceRow, T factor = 1) {
  113. for(uint col : range(Cols)) {
  114. result(targetRow, col) += result(sourceRow, col) * factor;
  115. source(targetRow, col) += source(sourceRow, col) * factor;
  116. }
  117. };
  118. const auto sub = [&](uint targetRow, uint sourceRow, T factor = 1) {
  119. for(uint col : range(Cols)) {
  120. result(targetRow, col) -= result(sourceRow, col) * factor;
  121. source(targetRow, col) -= source(sourceRow, col) * factor;
  122. }
  123. };
  124. const auto mul = [&](uint row, T factor) {
  125. for(uint col : range(Cols)) {
  126. result(row, col) *= factor;
  127. source(row, col) *= factor;
  128. }
  129. };
  130. for(uint i : range(Cols)) {
  131. if(source(i, i) == 0) {
  132. for(uint row : range(Rows)) {
  133. if(source(row, i) != 0) {
  134. add(i, row);
  135. break;
  136. }
  137. }
  138. //matrix is not invertible:
  139. if(source(i, i) == 0) return {};
  140. }
  141. mul(i, T{1} / source(i, i));
  142. for(uint row : range(Rows)) {
  143. if(row == i) continue;
  144. sub(row, i, source(row, i));
  145. }
  146. }
  147. return result;
  148. }
  149. auto transpose() const -> Matrix<T, Cols, Rows> {
  150. Matrix<T, Cols, Rows> result;
  151. for(uint row : range(Rows)) {
  152. for(uint col : range(Cols)) {
  153. result(col, row) = target(row, col);
  154. }
  155. }
  156. return result;
  157. }
  158. static auto identity() -> Matrix {
  159. static_assert(Rows == Cols);
  160. Matrix result;
  161. for(uint row : range(Rows)) {
  162. for(uint col : range(Cols)) {
  163. result(row, col) = row == col;
  164. }
  165. }
  166. return result;
  167. }
  168. //debugging function: do not use in production code
  169. template<uint Pad = 0>
  170. auto _print() const -> void {
  171. for(uint row : range(Rows)) {
  172. nall::print("[ ");
  173. for(uint col : range(Cols)) {
  174. nall::print(pad(target(row, col), Pad, ' '), " ");
  175. }
  176. nall::print("]\n");
  177. }
  178. }
  179. protected:
  180. //same as operator(), but with easier to read syntax inside Matrix class
  181. auto target(uint row, uint col) -> T& { return values[row][col]; }
  182. auto target(uint row, uint col) const -> T { return values[row][col]; }
  183. T values[Rows][Cols]{};
  184. };
  185. }