matrix.c 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. #include "matrix.h"
  2. #include <assert.h>
  3. #include <math.h>
  4. #include <stdlib.h> // malloc
  5. #include <string.h> // memset
  6. #define M(m, _c, _r) m->data[(_c) + (_r) * m->c]
  7. sti_matrix* sti_matrix_new(int c, int r) {
  8. sti_matrix* mat;
  9. mat = malloc(sizeof(*mat) + sizeof(mat->data[0]) * r * c);
  10. mat->r = r;
  11. mat->c = c;
  12. return mat;
  13. }
  14. sti_matrix* sti_matrix_same_size(sti_matrix* m) {
  15. return sti_matrix_new(m->c, m->r);
  16. }
  17. sti_matrix* sti_matrix_size_for_mul(sti_matrix* a, sti_matrix* b) {
  18. return sti_matrix_new(b->c, a->r);
  19. }
  20. sti_matrix* sti_matrix_copy(sti_matrix* m) {
  21. sti_matrix* mat = sti_matrix_same_size(m);
  22. memcpy(mat->data, m->data, sizeof(*mat->data) * m->r * m->c);
  23. return mat;
  24. }
  25. // careful here...
  26. void sti_matrix_print(sti_matrix* m, FILE* f) {
  27. for(long r = 0; r < m->r; r++) {
  28. for(long c = 0; c < m->c; c++) {
  29. fprintf(f, "%.2f ", m->data[c + m->c * r]);
  30. }
  31. fprintf(f, "\n");
  32. }
  33. }
  34. void sti_matrix_clear(sti_matrix* m) {
  35. memset(m->data, 0, sizeof(m->data) * m->c * m->r);
  36. }
  37. void sti_matrix_set(sti_matrix* m, float v) {
  38. if(v == 0) {
  39. memset(m->data, 0, sizeof(m->data) * m->c * m->r);
  40. return;
  41. }
  42. long sz = m->c * m->r;
  43. for(int i = 0; i < sz; i++) {
  44. m->data[i] = v;
  45. }
  46. }
  47. void sti_matrix_load(sti_matrix* m, float* v) {
  48. memcpy(m->data, v, sizeof(m->data[0]) * m->c * m->r);
  49. }
  50. void sti_matrix_ident(sti_matrix* m) {
  51. for(int i = 0; i < m->c; i++)
  52. for(int j = 0; j < m->r; j++) {
  53. m->data[i + j * m->c] = i == j;
  54. }
  55. }
  56. void sti_matrix_rand(sti_matrix* m, float min, float max) {
  57. long len = m->c * m->r;
  58. float sz = max - min;
  59. for(long n = 0; n < len; n++) {
  60. float x = ((float)rand() * sz) / (float)RAND_MAX;
  61. m->data[n] = min + x;
  62. }
  63. }
  64. void sti_matrix_transpose(sti_matrix* a, sti_matrix* out) {
  65. assert(a->c * a->r <= out->c * out->r);
  66. out->r = a->c;
  67. out->c = a->r;
  68. for(int r = 0; r < a->r; r++)
  69. for(int c = r; c < a->c; c++) {
  70. float tmp;
  71. if(c < a->c) tmp = M(a, c, r);
  72. if(c < out->c) M(out, c, r) = M(a, r, c);
  73. if(c < a->c) M(out, r, c) = tmp;
  74. }
  75. }
  76. int sti_matrix_eq(sti_matrix* a, sti_matrix* b) {
  77. if(a->r != b->r || a->c != b->c) return 0;
  78. long len = a->c * a->r;
  79. for(long n = 0; n < len; n++) {
  80. if(a->data[n] != b->data[n]) return 0;
  81. }
  82. return 1;
  83. }
  84. sti_matrix* sti_matrix_mul(sti_matrix* a, sti_matrix* b) {
  85. sti_matrix* o;
  86. if(a->c != b->r) return NULL;
  87. o = sti_matrix_new(b->c, a->r);
  88. sti_matrix_mulp(a, b, o);
  89. return o;
  90. }
  91. // no checks for size match.
  92. void sti_matrix_mulp(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  93. long klim = a->c;
  94. for(int c = 0; c < b->c; c++)
  95. for(int r = 0; r < a->r; r++) {
  96. M(out, c, r) = 0;
  97. for(int k = 0; k < klim; k++) {
  98. M(out, c, r) += M(a, k, r) * M(b, c, k);
  99. }
  100. }
  101. }
  102. // multiplies a with the transpose of b
  103. sti_matrix* sti_matrix_mul_transb(sti_matrix* a, sti_matrix* b) {
  104. sti_matrix* o;
  105. if(a->c != b->c) return NULL;
  106. o = sti_matrix_new(b->r, a->r);
  107. sti_matrix_mulp_transb(a, b, o);
  108. return o;
  109. }
  110. void sti_matrix_mulp_transb(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  111. long klim = a->c;
  112. for(int c = 0; c < b->r; c++)
  113. for(int r = 0; r < a->r; r++) {
  114. M(out, c, r) = 0;
  115. for(int k = 0; k < klim; k++) {
  116. M(out, c, r) += M(a, k, r) * M(b, k, c);
  117. }
  118. }
  119. }
  120. #ifndef MIN
  121. #define MIN(a, b) (a < b ? a : b)
  122. #endif
  123. void sti_matrix_add(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  124. int c = MIN(out->c, MIN(a->c, b->c));
  125. int r = MIN(out->r, MIN(a->r, b->r));
  126. for(int j = 0; j < r; j++)
  127. for(int i = 0; i < c; i++) {
  128. M(out, i, j) = M(a, i, j) + M(a, i, j);
  129. }
  130. }
  131. void sti_matrix_sub(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  132. int c = MIN(out->c, MIN(a->c, b->c));
  133. int r = MIN(out->r, MIN(a->r, b->r));
  134. for(int j = 0; j < r; j++)
  135. for(int i = 0; i < c; i++) {
  136. M(out, i, j) = M(a, i, j) - M(a, i, j);
  137. }
  138. }
  139. void sti_matrix_scalar_mul(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  140. long sz = a->c * a->r;
  141. for(int i = 0; i < sz; i++) {
  142. out->data[i] = a->data[i] * b->data[i];
  143. }
  144. }
  145. void sti_matrix_scale(sti_matrix* a, float s, sti_matrix* out) {
  146. long sz = a->c * a->r;
  147. for(int i = 0; i < sz; i++) {
  148. out->data[i] = a->data[i] * s;
  149. }
  150. }
  151. // apply e^a[n]
  152. void sti_matrix_exp(sti_matrix* a, sti_matrix* out) {
  153. long sz = a->c * a->r;
  154. for(int i = 0; i < sz; i++) {
  155. out->data[i] = expf(a->data[i]);
  156. }
  157. }
  158. // simple flat sum of all values in the matrix
  159. float sti_matrix_sum(sti_matrix* a) {
  160. long sz = a->c * a->r;
  161. float sum = 0;
  162. for(int i = 0; i < sz; i++) {
  163. sum += a->data[i];
  164. }
  165. return sum;
  166. }
  167. void sti_matrix_softmax(sti_matrix* a, sti_matrix* out) {
  168. long sz = a->c * a->r;
  169. float sum = 0;
  170. for(int i = 0; i < sz; i++) {
  171. out->data[i] = expf(a->data[i]);
  172. sum += out->data[i];
  173. }
  174. float invsum = 1.0 / sum;
  175. for(int i = 0; i < sz; i++) {
  176. out->data[i] *= invsum;
  177. }
  178. }
  179. void sti_matrix_min(sti_matrix* a, float minval, sti_matrix* out) {
  180. long sz = a->c * a->r;
  181. for(int i = 0; i < sz; i++) {
  182. out->data[i] = fminf(a->data[i], minval);
  183. }
  184. }
  185. void sti_matrix_max(sti_matrix* a, float maxval, sti_matrix* out) {
  186. long sz = a->c * a->r;
  187. for(int i = 0; i < sz; i++) {
  188. out->data[i] = fmaxf(a->data[i], maxval);
  189. }
  190. }
  191. void sti_matrix_clamp(sti_matrix* a, float minval, float maxval, sti_matrix* out) {
  192. long sz = a->c * a->r;
  193. for(int i = 0; i < sz; i++) {
  194. out->data[i] = fminf(minval, fmaxf(a->data[i], maxval));
  195. }
  196. }
  197. void sti_matrix_tanh_clamp(sti_matrix* a, sti_matrix* out) {
  198. long sz = a->c * a->r;
  199. for(int i = 0; i < sz; i++) {
  200. out->data[i] = tanhf(a->data[i]);
  201. }
  202. }
  203. void sti_matrix_relu_0(sti_matrix* a, sti_matrix* out) {
  204. long sz = a->c * a->r;
  205. for(int i = 0; i < sz; i++) {
  206. out->data[i] = fmax(0, a->data[i]);
  207. }
  208. }
  209. void sti_matrix_relu_half(sti_matrix* a, sti_matrix* out) {
  210. long sz = a->c * a->r;
  211. for(int i = 0; i < sz; i++) {
  212. out->data[i] = fmax(0, a->data[i] - .5f) + .5f;
  213. }
  214. }
  215. void sti_matrix_relu_n(sti_matrix* a, float n, sti_matrix* out) {
  216. long sz = a->c * a->r;
  217. for(int i = 0; i < sz; i++) {
  218. out->data[i] = fmax(0, a->data[i] - n) + n;
  219. }
  220. }
  221. void sti_matrix_silu(sti_matrix* a, sti_matrix* out) {
  222. long sz = a->c * a->r;
  223. for(int i = 0; i < sz; i++) {
  224. out->data[i] = a->data[i] / (1.f + expf(-a->data[i]));
  225. }
  226. }
  227. // mean squared error: SUM( (a - b)^2 )
  228. float sti_matrix_mse(sti_matrix* a, sti_matrix* b) {
  229. long sz = a->c * a->r;
  230. float sum = 0;
  231. for(int i = 0; i < sz; i++) {
  232. float x = a->data[i] - b->data[i];
  233. sum += x * x;
  234. }
  235. return sum / sz;
  236. }