vq_train.c 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. /*Daala video codec
  2. Copyright (c) 2012-2014 Daala project contributors. All rights reserved.
  3. Redistribution and use in source and binary forms, with or without
  4. modification, are permitted provided that the following conditions are met:
  5. - Redistributions of source code must retain the above copyright notice, this
  6. list of conditions and the following disclaimer.
  7. - Redistributions in binary form must reproduce the above copyright notice,
  8. this list of conditions and the following disclaimer in the documentation
  9. and/or other materials provided with the distribution.
  10. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  11. AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  12. IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  13. DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
  14. FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  15. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  16. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  17. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  18. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  19. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.*/
  #include <limits.h>
  #include <math.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>
  #include <time.h>
  #include "od_defs.h"
  #include "../src/dct.h"
  /* Larger of a and b.  Each argument may be evaluated twice, so pass only
     side-effect-free expressions. */
  #define MAX(a, b) ((b) < (a) ? (a) : (b))
  /* Capacity limits on the number of codebook entries and on the vector
     dimension. */
  #define MAX_ENTRIES (1 << 12)
  #define MAX_DIMS (1 << 7)
  /* Switch to #if 1 to force NUM_PROCS to a single thread. */
  #if 0
  # undef NUM_PROCS
  # define NUM_PROCS (1)
  #endif
  /* Dot product of the first n elements of x and y, accumulated from index 0
     upward.  Returns 0 when n <= 0. */
  static double inner_prod(const double *x, const double *y, int n) {
    double acc = 0;
    int k = 0;
    while (k < n) {
      acc += x[k]*y[k];
      k++;
    }
    return acc;
  }
  40. static void normalize(double *x, int n) {
  41. int i;
  42. double sum;
  43. sum = 1e-30;
  44. for (i = 0; i < n; i++) sum += x[i]*x[i];
  45. sum = 1./sqrt(sum);
  46. for (i = 0; i < n; i++) x[i] *= sum;
  47. }
  /* Returns 2 - 2*c, where c is the largest inner product between |data| and
     a K=2 PVQ codeword; for unit-norm data this is the squared distance to
     the closest such codeword.  Only two codeword shapes matter: both pulses
     on the largest-magnitude coefficient (c = largest magnitude), or one
     pulse on each of the two largest magnitudes (c = their sum / sqrt(2)). */
  static double pvq_dist_k2(const double *data, int n) {
    /* 1/sqrt(2); same double value as the non-ISO M_SQRT1_2 macro. */
    const double inv_sqrt2 = 0.70710678118654752440;
    double first = -1;
    double second = -1;
    double pair;
    int k;
    for (k = 0; k < n; k++) {
      double mag = fabs(data[k]);
      /* Written as !(a > b) so NaN inputs are skipped exactly as before. */
      if (!(mag > second)) continue;
      if (mag > first) {
        second = first;
        first = mag;
      }
      else {
        second = mag;
      }
    }
    pair = inv_sqrt2*(first + second);
    return 2 - 2*(first > pair ? first : pair);
  }
  /* Returns the index of the codebook row (nb_entries rows of n doubles)
     whose inner product with data has the largest magnitude; the first such
     row wins ties.  If sign is non-NULL it receives +1 when that inner
     product is positive and -1 otherwise.  If err is non-NULL it receives
     2 - 2*|inner product|, the squared distance to the sign-adjusted row when
     both vectors have unit norm.  With nb_entries <= 0 the result is index 0,
     sign 1 and err 4. */
  static int find_nearest(const double *data, const double *codebook, int nb_entries,
   int n, double *sign, double *err) {
    int winner = 0;
    double winner_mag = -1;
    double winner_sign = 1;
    int entry;
    for (entry = 0; entry < nb_entries; entry++) {
      const double *row = &codebook[entry*n];
      double dot = 0;
      double mag;
      int k;
      /* Dot product accumulated in the same order as inner_prod(). */
      for (k = 0; k < n; k++) dot += data[k]*row[k];
      mag = fabs(dot);
      if (!(mag > winner_mag)) continue;
      winner_mag = mag;
      winner_sign = dot > 0 ? 1 : -1;
      winner = entry;
    }
    if (sign != NULL) *sign = winner_sign;
    if (err != NULL) *err = 2 - 2*winner_mag;
    return winner;
  }
  /* Fills codebook (nb_entries rows of n doubles) with unit-norm starting
     vectors.  Each entry is a randomly chosen row of data (nb_vectors rows of
     n doubles) plus a small perturbation drawn from {-.01, 0, +.01}, so that
     picking the same row twice still yields distinct entries.
     If nb_vectors <= 0 there is no row to sample (rand()%0 would be
     undefined), so entries are filled with random directions instead and
     data is not read.  The sequence of rand() calls for nb_vectors > 0 is
     unchanged, so a given srand() seed reproduces the same codebook. */
  void vq_rand_init(const double *data, int nb_vectors, double *codebook,
   int nb_entries, int n) {
    int i;
    int j;
    for (i = 0; i < nb_entries; i++) {
      double *entry;
      entry = &codebook[(size_t)i*n];
      if (nb_vectors > 0) {
        int id;
        id = rand()%nb_vectors;
        for (j = 0; j < n; j++) {
          entry[j] = data[(size_t)id*n + j] + .01*(rand()%3 - 1);
        }
      }
      else {
        for (j = 0; j < n; j++) entry[j] = (double)rand()/RAND_MAX - .5;
      }
      normalize(entry, n);
    }
  }
  106. double vq_train(const double *data, int nb_vectors, double *codebook,
  107. int nb_entries, int n, int nb_iter, int exclude_pvq) {
  108. int i;
  109. int iter;
  110. double rms[NUM_PROCS];
  111. double *accum;
  112. accum = (double *)malloc(MAX_ENTRIES*MAX_DIMS*NUM_PROCS*sizeof(*accum));
  113. for (iter = 0; iter < nb_iter; iter++) {
  114. for (i = 0; i < NUM_PROCS; i++) rms[i] = 0;
  115. memset(accum,0,nb_entries*n*NUM_PROCS*sizeof(*accum));
  116. #pragma omp parallel for schedule(dynamic)
  117. for (i = 0; i < nb_vectors; i++) {
  118. int tid;
  119. int id;
  120. double sign;
  121. double pvq_err;
  122. double err;
  123. tid=OD_OMP_GET_THREAD;
  124. id = find_nearest(&data[i*n], codebook, nb_entries, n, &sign, &err);
  125. pvq_err = pvq_dist_k2(&data[i*n], n);
  126. /*printf("%f ", err);*/
  127. if (!exclude_pvq || err < pvq_err) {
  128. int j;
  129. int offset;
  130. rms[tid] += err;
  131. offset = nb_entries*n*tid + id*n;
  132. for (j = 0; j < n; j++) accum[offset + j] += sign*data[i*n + j];
  133. }
  134. else rms[tid] += pvq_err;
  135. }
  136. for (i = 1; i < NUM_PROCS; i++) {
  137. int j;
  138. int offset;
  139. offset = nb_entries*n*i;
  140. for (j = 0; j < nb_entries*n; j++) accum[j] += accum[offset+j];
  141. }
  142. for (i = 1; i < NUM_PROCS; i++) rms[0] += rms[i];
  143. for (i = 0; i < nb_entries; i++) normalize(&accum[i*n], n);
  144. for (i = 0; i < nb_entries*n; i++) codebook[i] = accum[i];
  145. rms[0] = sqrt(rms[0]/nb_vectors);
  146. fprintf(stderr, "RMS: %f\n", rms[0]);
  147. }
  148. free(accum);
  149. return rms[0];
  150. }
  151. int main(int argc, char **argv)
  152. {
  153. int i;
  154. int j;
  155. int nb_vectors;
  156. int nb_entries;
  157. int ndim;
  158. double *data;
  159. double *codebook;
  160. double rms;
  161. unsigned seed;
  162. seed = time(NULL);
  163. srand(seed);
  164. if (argc != 4) {
  165. fprintf(stderr, "usage: %s <dimensions> <max vectors> <bits>\n",argc > 0? argv[0] : '\0');
  166. return 1;
  167. }
  168. ndim = atoi(argv[1]);
  169. nb_vectors = atoi(argv[2]);
  170. nb_entries = 1<<atoi(argv[3]);
  171. OD_OMP_SET_THREADS(NUM_PROCS);
  172. data = (double *)malloc(nb_vectors*ndim*sizeof(*data));
  173. codebook = (double *)malloc(nb_entries*ndim*sizeof(*codebook));
  174. if (data == NULL || codebook == NULL) {
  175. fprintf(stderr, "malloc() failed, giving up.\n");
  176. return 1;
  177. }
  178. for (i = 0;i < nb_vectors; i++) {
  179. if (feof(stdin))
  180. break;
  181. for (j = 0; j < ndim; j++) {
  182. if(scanf("%lf ", &data[i*ndim + j]) != 1) exit(EXIT_FAILURE);
  183. }
  184. normalize(&data[i*ndim], ndim);
  185. }
  186. nb_vectors = i;
  187. fprintf(stderr, "read %d vectors\n", nb_vectors);
  188. vq_rand_init(data, nb_vectors, codebook, nb_entries, ndim);
  189. rms = vq_train(data, nb_vectors, codebook, nb_entries, ndim, 100, 1);
  190. #if 0
  191. for (i = 0; i < nb_vectors; i++)
  192. {
  193. double sign;
  194. int nearest;
  195. nearest = find_nearest(&data[i*ndim], codebook, nb_entries, ndim, &sign,
  196. NULL);
  197. printf("%d %f\n", nearest, sign);
  198. }
  199. #endif
  200. printf("/* Automatically generated by vq_train. */\n");
  201. printf("/* Seed was %u. */\n", seed);
  202. printf("/* RMS training error is %f. */\n", rms);
  203. printf("const double codebook[%d*%d] = {\n", nb_entries, ndim);
  204. for (i = 0; i < nb_entries; i++) {
  205. for(j = 0; j < ndim; j++) printf("%f, ", codebook[i*ndim + j]);
  206. printf("\n");
  207. }
  208. printf("};\n");
  209. free(data);
  210. free(codebook);
  211. return 0;
  212. }