matrix.h 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #ifndef __matrix_h__
  2. #define __matrix_h__
  3. #include <stdlib.h>
  4. #include <math.h>
  5. #include "vector.h"
  6. class Matrix
  7. {
  8. friend class Matrix3d;
  9. friend ostream& operator<<(ostream&, const Matrix&);
  10. friend istream& operator>>(istream&, Matrix&);
  11. friend Matrix operator*(double, const Matrix&);
  12. friend double max(const Matrix&);
  13. friend double min(const Matrix&);
  14. friend double norm(const Matrix&);
  15. friend double norm1(const Matrix&);
  16. friend double norm2(const Matrix&);
  17. friend double norm2_sqv(const Matrix&);
  18. friend void del(Matrix&);
  19. public:
  20. Matrix(int n=0, int m=0, double = 0.0);
  21. Matrix(const Matrix&);
  22. Matrix(const Vector&);
  23. Matrix& operator=(const Matrix&);
  24. ~Matrix();
  25. void free();
  26. int dim_i() const { return d_i; }
  27. int dim_j() const { return d_j; }
  28. Vector& row(int) const;
  29. Vector col(int) const;
  30. Vector diag() const;
  31. Matrix trans() const;
  32. Matrix abs() const;
  33. Matrix& resize(int,int);
  34. void swap_rows(int,int);
  35. operator Vector() const;
  36. int operator==(const Matrix&) const;
  37. int operator!=(const Matrix& x) const { return !(*this == x); }
  38. Vector& operator[](int i) const { return row(i); }
  39. double& operator()(int, int);
  40. double& operator()(int, int) const;
  41. Matrix operator+(const Matrix&);
  42. Matrix operator-(const Matrix&);
  43. Matrix operator-();
  44. Matrix& operator+=(const Matrix&);
  45. Matrix& operator-=(const Matrix&);
  46. Matrix& operator/=(double);
  47. Matrix& operator*=(double);
  48. Matrix operator*(double);
  49. Matrix operator*(double) const;
  50. Matrix operator*(const Matrix&);
  51. Vector operator*(const Vector& v) { return Vector(*this * Matrix(v)); }
  52. Matrix operator/(double);
  53. Matrix operator^(const Matrix&) const;
  54. protected:
  55. double& elem(int i, int j) const { return v[i]->v[j]; }
  56. private:
  57. Vector** v;
  58. int d_i;
  59. int d_j;
  60. void flip_rows(int,int);
  61. void check_dimensions(const Matrix&) const;
  62. };
  63. inline Matrix& Matrix::operator=(const Matrix& mat) {
  64. int i, j;
  65. if (d_i != mat.d_i || d_j != mat.d_j) {
  66. for(i = 0; i < d_i; i++) delete v[i];
  67. delete v;
  68. d_i = mat.d_i;
  69. d_j = mat.d_j;
  70. v = new Vector*[d_i];
  71. for(i = 0; i < d_i; i++) v[i] = new Vector(d_j);
  72. }
  73. for(i = 0; i < d_i; i++)
  74. for(j = 0; j < d_j; j++) elem(i,j) = mat.elem(i,j);
  75. return (*this);
  76. }
  77. inline int Matrix::operator==(const Matrix& x) const {
  78. int i, j;
  79. if (d_i != x.d_i || d_j != x.d_j) return (0);
  80. for(i = 0; i < d_i; i++)
  81. for(j = 0; j < d_j; j++)
  82. if (elem(i,j) != x.elem(i,j)) return (0);
  83. return (1);
  84. }
  85. inline Vector& Matrix::row(int i) const {
  86. if (i < 0 || i >= d_i) {
  87. cerr << "Matrix: row index out of range" << endl;
  88. exit(1);
  89. }
  90. return (*v[i]);
  91. }
  92. inline double& Matrix::operator()(int i, int j) {
  93. if (i < 0 || i >= d_i) {
  94. cerr << "Matrix: row index out of range" << endl;
  95. exit(1);
  96. }
  97. if (j < 0 || j >= d_j) {
  98. cerr << "Matrix: col index out of range" << endl;
  99. exit(1);
  100. }
  101. return (elem(i,j));
  102. }
  103. inline double& Matrix::operator()(int i, int j) const {
  104. if (i < 0 || i >= d_i) {
  105. cerr << "Matrix: row index out of range" << endl;
  106. exit(1);
  107. }
  108. if (j < 0 || j >= d_j) {
  109. cerr << "Matrix: col index out of range" << endl;
  110. exit(1);
  111. }
  112. return (elem(i,j));
  113. }
  114. inline Vector Matrix::col(int i) const {
  115. if (i < 0 || i >= d_j) {
  116. cerr << "Matrix: col index out of range" << endl;
  117. exit(1);
  118. }
  119. Vector result(d_i);
  120. int j = d_i;
  121. while (j--) result.v[j] = elem(j,i);
  122. return (result);
  123. }
  124. inline Vector Matrix::diag() const {
  125. if (d_i != d_j) {
  126. cerr << "Matrix: diag defined only if d_i = d_j" << endl;
  127. exit(1);
  128. }
  129. Vector result(d_i);
  130. int j = d_i;
  131. while (j--) result.v[j] = elem(j,j);
  132. return (result);
  133. }
  134. inline Matrix::operator Vector() const {
  135. if (d_j != 1) {
  136. cerr << "error: cannot make vector from matrix" << endl;
  137. exit(1);
  138. }
  139. return (col(0));
  140. }
  141. inline Matrix Matrix::operator+(const Matrix& mat) {
  142. int i, j;
  143. check_dimensions(mat);
  144. Matrix result(d_i,d_j);
  145. for(i = 0; i < d_i; i++)
  146. for(j = 0; j < d_j; j++)
  147. result.elem(i,j) = elem(i,j) + mat.elem(i,j);
  148. return (result);
  149. }
  150. inline Matrix Matrix::operator-(const Matrix& mat) {
  151. int i, j;
  152. check_dimensions(mat);
  153. Matrix result(d_i,d_j);
  154. for(i = 0; i < d_i; i++)
  155. for(j = 0; j < d_j; j++)
  156. result.elem(i,j) = elem(i,j) - mat.elem(i,j);
  157. return (result);
  158. }
  159. inline Matrix Matrix::operator-( ) {
  160. int i, j;
  161. Matrix result(d_i,d_j);
  162. for(i = 0; i < d_i; i++)
  163. for(j= 0; j < d_j; j++)
  164. result.elem(i,j) = -elem(i,j);
  165. return (result);
  166. }
  167. inline Matrix& Matrix::operator+=(const Matrix& mat) {
  168. int i, j;
  169. check_dimensions(mat);
  170. for(i = 0; i < d_i; i++)
  171. for(j = 0; j < d_j; j++)
  172. elem(i,j) += mat.elem(i,j);
  173. return (*this);
  174. }
  175. inline Matrix& Matrix::operator-=(const Matrix& mat) {
  176. int i, j;
  177. check_dimensions(mat);
  178. for(i = 0; i < d_i; i++)
  179. for(j = 0; j < d_j; j++)
  180. elem(i,j) -= mat.elem(i,j);
  181. return (*this);
  182. }
  183. inline Matrix& Matrix::operator/=(double x) {
  184. int i, j;
  185. if (x == 0) {
  186. cerr << "Matrix/=: divided by zero" << endl;
  187. exit(1);
  188. }
  189. for(i = 0; i < d_i; i++)
  190. for(j = 0; j < d_j; j++)
  191. elem(i,j) /= x;
  192. return (*this);
  193. }
  194. inline Matrix& Matrix::operator*=(double x) {
  195. int i, j;
  196. for(i = 0; i < d_i; i++)
  197. for(j = 0; j < d_j; j++)
  198. elem(i,j) *= x;
  199. return (*this);
  200. }
  201. inline Matrix Matrix::operator*(double f) {
  202. int i, j;
  203. Matrix result(d_i,d_j);
  204. for(i = 0; i < d_i; i++)
  205. for(j = 0; j < d_j; j++)
  206. result.elem(i,j) = elem(i,j) * f;
  207. return (result);
  208. }
  209. inline Matrix Matrix::operator*(double f) const {
  210. int i, j;
  211. Matrix result(d_i,d_j);
  212. for(i = 0; i < d_i; i++)
  213. for(j = 0; j < d_j; j++)
  214. result.elem(i,j) = elem(i,j) * f;
  215. return (result);
  216. }
  217. inline Matrix Matrix::operator*(const Matrix& mat) {
  218. if (d_j != mat.d_i) {
  219. cerr << "matrix multiplication: incompatible matrix types" << endl;
  220. exit(1);
  221. }
  222. Matrix result(d_i, mat.d_j);
  223. int i,j;
  224. for (i = 0; i < mat.d_j; i++)
  225. for (j = 0; j < d_i; j++) result.elem(j,i) = *v[j] * mat.col(i);
  226. return (result);
  227. }
  228. inline Matrix Matrix::operator/(double a) {
  229. if(a == 0) {
  230. cerr << "Matrix: divided by zero" << endl;
  231. exit(1);
  232. }
  233. a = 1. / a;
  234. return (Matrix(*this *a ));
  235. }
  236. inline Matrix Matrix::operator^(const Matrix& mat) const {
  237. if (d_i != mat.d_i || d_j != mat.d_j) {
  238. cerr << "Matrix=: matrixes have different sizes" << endl;
  239. exit(1);
  240. }
  241. Matrix res(d_i,d_j);
  242. for(int i = 0; i < d_i; i++)
  243. for(int j = 0; j < d_j; j++)
  244. res.elem(i,j) = elem(i,j)*mat.elem(i,j);
  245. return (res);
  246. }
  247. inline Matrix Matrix::abs() const {
  248. Matrix result(d_i,d_j);
  249. for(int i = 0; i < d_i; i++)
  250. for(int j = 0; j < d_j; j++)
  251. result(i,j) = fabs(elem(i,j));
  252. return (result);
  253. }
  254. inline void Print(const Matrix& m, ostream& out=cout) { out << m; }
  255. inline void Read(Matrix& m, istream& in=cin) { in >> m; }
  256. #endif