// vector.cpp - tests for simple::geom::vector: matrix/vector products, dot
// product, row/column broadcasting, polynomial multiplication, structured
// bindings and constexpr use.
  #include "simple/geom/vector.hpp"
  #include "simple/geom/bool_algebra.hpp"
  #include <cassert>
  #include <cstddef>
  #include <fstream>
  #include <vector>
  using namespace simple;
  using geom::vector;
  using int2 = vector<int,2>;
  using int3 = vector<int,3>;
  using int2x2 = vector<int2, 2>;
  using int3x3 = vector<int3, 3>;
  /// Extracts the components of `v` from `is` one after another, in the order
  /// a range-for visits them. Nested vectors (e.g. matrices) recurse through
  /// this same operator. Once an extraction fails, the remaining components
  /// are left untouched. The failure stays in the stream state, so callers
  /// can write `while (is >> v)`.
  template <typename Vector>
  std::istream& operator>>(std::istream& is, Vector& v)
  {
      for (auto&& component : v)
      {
          if (!is)
              break; // an earlier component failed to read
          is >> component;
      }
      return is;
  }
  18. void SquareMatrixMultiplication()
  19. {
  20. std::vector<int3x3> matrices;
  21. std::ifstream test_data("square_matrix.data");
  22. int3x3 matrix;
  23. while(test_data >> matrix)
  24. matrices.push_back(matrix);
  25. assert(matrices.size() > 3);
  26. assert(matrices.size() % 3 == 0);
  27. for(auto i = matrices.begin(); i != matrices.end(); i+=3)
  28. {
  29. int3x3 A = *i;
  30. int3x3 B = *(i+1);
  31. int3x3 AxB = *(i+2);
  32. assert(( B(A) == AxB ));
  33. }
  34. }
  35. void MatrixVectorMultiplication()
  36. {
  37. struct test_case
  38. {
  39. int3x3 matrix;
  40. int3 in, out;
  41. };
  42. std::vector<test_case> tests;
  43. std::ifstream test_data("matrix_vector.data");
  44. while(test_data)
  45. {
  46. test_case test;
  47. test_data >> test.in;
  48. test_data >> test.matrix;
  49. test_data >> test.out;
  50. tests.push_back(test);
  51. }
  52. assert(tests.size() > 2);
  53. tests.pop_back();
  54. for(auto&& [matrix, in, out] : tests)
  55. assert( out == matrix(in) );
  56. }
  57. void DotProduct()
  58. {
  59. struct test_case
  60. {
  61. int3 in1, in2;
  62. int out;
  63. };
  64. std::vector<test_case> tests;
  65. std::ifstream test_data("dot_product.data");
  66. while(test_data)
  67. {
  68. test_case test;
  69. test_data >> test.in1;
  70. test_data >> test.in2;
  71. test_data >> test.out;
  72. tests.push_back(test);
  73. }
  74. assert(tests.size() > 2);
  75. tests.pop_back();
  76. for(auto&& [in1, in2, out] : tests)
  77. {
  78. assert( out == in1(in2) );
  79. assert( out == in2(in1) );
  80. }
  81. }
  82. void NonSquareMatrixMultiplication()
  83. {
  84. using int2x3 = vector<int2, 3>;
  85. using int3x2 = vector<int3, 2>;
  86. using int3x5 = vector<int3, 5>;
  87. using int2x5 = vector<int2, 5>;
  88. int2x3 a{ int2x3::array {{
  89. {1, 2},
  90. {2, 1},
  91. {1, 2},
  92. }}};
  93. int3x5 b{ int3x5::array {{
  94. {1, 2, 3},
  95. {3, 1, 2},
  96. {2, 3, 1},
  97. {3, 2, 1},
  98. {1, 3, 2}
  99. }}};
  100. int2x5 ans{ int2x5::array {{
  101. {8, 10},
  102. {7, 11},
  103. {9, 9},
  104. {8, 10},
  105. {9, 9}
  106. }}};
  107. assert ( ans == a(b) );
  108. struct test_case
  109. {
  110. int3x2 in1;
  111. int2x3 in2;
  112. int2x2 out;
  113. };
  114. std::vector<test_case> tests;
  115. std::ifstream test_data("matrix.data");
  116. while(test_data)
  117. {
  118. test_case test;
  119. test_data >> test.in1;
  120. test_data >> test.in2;
  121. test_data >> test.out;
  122. tests.push_back(test);
  123. }
  124. assert(tests.size() > 2);
  125. tests.pop_back();
  126. for(auto&& [in1, in2, out] : tests)
  127. assert( out == in2(in1) );
  128. }
  // TODO: add tests for the remaining vector operators.
  130. void RowColumnVectorAndMatrix()
  131. {
  132. const vector row(0.1f, 0.2f, 0.3f);
  133. auto matrix = vector {
  134. vector(1.0f, 2.0f, 3.0f),
  135. vector(4.0f, 5.0f, 6.0f),
  136. vector(7.0f, 8.0f, 9.0f),
  137. };
  138. assert(( matrix + row ==
  139. vector{
  140. vector(1.1f, 2.2f, 3.3f),
  141. vector(4.1f, 5.2f, 6.3f),
  142. vector(7.1f, 8.2f, 9.3f),
  143. }
  144. ));
  145. assert(( row + matrix ==
  146. vector{
  147. vector(1.1f, 2.2f, 3.3f),
  148. vector(4.1f, 5.2f, 6.3f),
  149. vector(7.1f, 8.2f, 9.3f),
  150. }
  151. ));
  152. matrix += row;
  153. assert(( matrix ==
  154. vector{
  155. vector(1.1f, 2.2f, 3.3f),
  156. vector(4.1f, 5.2f, 6.3f),
  157. vector(7.1f, 8.2f, 9.3f),
  158. }
  159. ));
  160. const vector column{
  161. vector(0.1f),
  162. vector(0.2f),
  163. vector(0.3f),
  164. };
  165. matrix = vector {
  166. vector(1.0f, 2.0f, 3.0f),
  167. vector(4.0f, 5.0f, 6.0f),
  168. vector(7.0f, 8.0f, 9.0f),
  169. };
  170. assert(( matrix + column ==
  171. vector{
  172. vector(1.1f, 2.1f, 3.1f),
  173. vector(4.2f, 5.2f, 6.2f),
  174. vector(7.3f, 8.3f, 9.3f),
  175. }
  176. ));
  177. assert(( column + matrix ==
  178. vector{
  179. vector(1.1f, 2.1f, 3.1f),
  180. vector(4.2f, 5.2f, 6.2f),
  181. vector(7.3f, 8.3f, 9.3f),
  182. }
  183. ));
  184. matrix += column;
  185. assert(( matrix ==
  186. vector{
  187. vector(1.1f, 2.1f, 3.1f),
  188. vector(4.2f, 5.2f, 6.2f),
  189. vector(7.3f, 8.3f, 9.3f),
  190. }
  191. ));
  192. assert
  193. (
  194. vector
  195. (
  196. vector(10),
  197. vector(20),
  198. vector(30)
  199. )
  200. +
  201. vector(1,2,3)
  202. ==
  203. vector
  204. (
  205. vector(11, 12, 13),
  206. vector(21, 22, 23),
  207. vector(31, 32, 33)
  208. )
  209. );
  210. assert
  211. (
  212. vector(1,2,3)
  213. +
  214. vector
  215. (
  216. vector(10),
  217. vector(20),
  218. vector(30)
  219. )
  220. ==
  221. vector
  222. (
  223. vector(11, 12, 13),
  224. vector(21, 22, 23),
  225. vector(31, 32, 33)
  226. )
  227. );
  228. }
  229. void PolynomialMultiplication()
  230. {
  231. const vector p1(1, -1, 3, 2);
  232. const vector p2{
  233. vector(4),
  234. vector(2),
  235. vector(1),
  236. vector(-5),
  237. };
  238. // get a matrix all combination
  239. const auto all_combos = p1 * p2;
  240. constexpr auto degree = std::max(p1.dimensions, p2.dimensions);
  241. auto result = vector<int, degree + degree - 1>{};
  242. // sum the secondary diagonals of the matrix
  243. constexpr size_t x = -1;
  244. result += all_combos[0].mix<0,1,2,3,x,x,x>(0);
  245. result += all_combos[1].mix<x,0,1,2,3,x,x>(0);
  246. result += all_combos[2].mix<x,x,0,1,2,3,x>(0);
  247. result += all_combos[3].mix<x,x,x,0,1,2,3>(0);
  248. assert(result == vector(4,-2,11,8,12,-13,-10));
  249. }
  250. void StructredBinding()
  251. {
  252. {
  253. vector v(1,2,3);
  254. auto [a,b,c] = v;
  255. assert( a == 1);
  256. assert( b == 2);
  257. assert( c == 3);
  258. }
  259. {
  260. vector v(1,2,3);
  261. const auto [a,b,c] = v;
  262. assert( a == 1);
  263. assert( b == 2);
  264. assert( c == 3);
  265. }
  266. }
  267. constexpr bool Constexprness()
  268. {
  269. constexpr int3x3 A{}, B{};
  270. constexpr int3 a{}, b{};
  271. void(A(B)); void(B(A)); void(A(a)); void(B(a)); void(a(b));
  272. return true;
  273. }
  274. int main()
  275. {
  276. SquareMatrixMultiplication();
  277. MatrixVectorMultiplication();
  278. DotProduct();
  279. NonSquareMatrixMultiplication();
  280. RowColumnVectorAndMatrix();
  281. PolynomialMultiplication();
  282. static_assert(Constexprness());
  283. return 0;
  284. }