agreement.cc 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/examples - Demo shape agreement rules
  3. // (c) Daniel Llorens - 2015-2016
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #include "ra/ra.hh"
  9. #include <iostream>
  10. using std::cout, std::endl;
  11. int main()
  12. {
  13. // The general shape agreement rule is 'prefix agreement': all the first
  14. // dimensions must match, all the second dimensions, etc. If some arguments
  15. // have lower rank than others, then the missing dimensions are ignored.
  16. // For example:
  17. ra::Big<float, 3> A({3, 4, 5}, 1.);
  18. ra::Big<float, 2> B({3, 4}, 2.);
  19. ra::Big<float, 1> C({3}, 3.);
  20. ra::Big<float, 3> X({3, 4, 5}, 99.);
  21. // In the following expression, the shapes of the arguments are:
  22. // A: [3 4 5]
  23. // B: [3 4]
  24. // C: [3]
  25. // X: [3 4 5] (taken from the shape of the right hand side)
  26. // All the first dimensions are 3, all the second dimensions are 4, and all
  27. // the third dimensions are 5, so the expression is valid.
  28. // Note that the agreement rules are applied equally to the destination argument.
  29. X = map([](auto && a, auto && b, auto && c) { return a+b-c; }, A, B, C);
  30. cout << "\nX: " << X << endl;
  31. // (you can write the expression above as X = A+B-C).
  32. // This rule comes from the array language J (for function rank 0; see J's
  33. // documentation). Obvious examples include:
  34. {
  35. // multiply any array by a scalar. The shape of a scalar is [];
  36. // therefore, a scalar agrees with anything.
  37. ra::Big<float, 2> X = B*7.;
  38. cout << "\nB*7: " << X << endl;
  39. }
  40. {
  41. // multiply each row of B by a different element of C, X(i, j) = B(i, j)*C(i)
  42. ra::Big<float, 2> X = B*C;
  43. cout << "\nB*C: " << X << endl;
  44. }
  45. {
  46. // multiply arrays componentwise (identical shapes agree).
  47. ra::Big<float, 2> X = B*B;
  48. cout << "\nB*B: " << X << endl;
  49. }
  50. // Some special expressions, such as tensor indices, do not have a
  51. // shape. Therefore they need to be accompanied by some other expression
  52. // that does have a shape, or the overall expression is not valid.
  53. {
  54. constexpr auto i = ra::iota<0>();
  55. constexpr auto j = ra::iota<1>();
  56. // That's why you can do
  57. ra::Big<float, 2> X({3, 4}, i-j);
  58. cout << "\ni-j: " << X << endl;
  59. // but the following would be invalid:
  60. // ra::Big<float, 2> X = i-j; // no shape to construct X with
  61. }
  62. // Axis insertion lets you match arguments more flexibly than simple prefix matching.
  63. {
  64. ra::Big<float, 2> A({3, 4}, 0);
  65. ra::Big<float, 1> b({3}, ra::_0);
  66. ra::Big<float, 1> c({4}, ra::_0);
  67. // Compare:
  68. // [3 4] matches [3] - normal prefix matching. Assign b(i) to A(i, ...)
  69. A = b;
  70. cout << "\nA: " << A << endl;
  71. // [3 4] matches [X 4] - skip 1 dimension when matching. Assign c(i) to A(..., i)
  72. A = c(ra::insert<1>);
  73. cout << "\nA: " << A << endl;
  74. }
  75. return 0;
  76. }