init_intra_xform.c 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752
  1. /*Daala video codec
  2. Copyright (c) 2012 Daala project contributors. All rights reserved.
  3. Redistribution and use in source and binary forms, with or without
  4. modification, are permitted provided that the following conditions are met:
  5. - Redistributions of source code must retain the above copyright notice, this
  6. list of conditions and the following disclaimer.
  7. - Redistributions in binary form must reproduce the above copyright notice,
  8. this list of conditions and the following disclaimer in the documentation
  9. and/or other materials provided with the distribution.
  10. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  11. AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  12. IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  13. DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
  14. FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  15. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  16. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  17. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  18. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  19. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.*/
  20. #ifdef HAVE_CONFIG_H
  21. #include "config.h"
  22. #endif
  23. #include <stdio.h>
  24. #include <stdlib.h>
  25. #include <float.h>
  26. #include <limits.h>
  27. #include <math.h>
  28. #include <string.h>
  29. #include "intra_fit_tools.h"
  30. #include "svd.h"
  31. #include "cholesky.h"
  32. #include "../src/dct.h"
  33. #include "../src/intra.h"
  34. #define INTRA_NO_RDO (1)
  35. int ExCount[3];
  36. double Ex[3][B_SZ*B_SZ];
  37. #define P0_COEF 0.002
  38. #if !INTRA_NO_RDO
  39. /* These weight are estimated from the entropy of the residual (in one older run) on subset1-y4m */
  40. const float satd_weights[3][B_SZ*B_SZ] =
  41. {{0.053046, 0.108601, 0.208930, 0.267489, 0.099610, 0.151836, 0.276093, 0.333044,
  42. 0.180490, 0.256407, 0.420637, 0.485639, 0.224520, 0.295232, 0.464384, 0.519662,},
  43. {0.209979, 0.381558, 0.657913, 0.793790, 0.362177, 0.505781, 0.821064, 0.926135,
  44. 0.603069, 0.779381, 1.078780, 1.151472, 0.725735, 0.864047, 1.135539, 1.175140,},
  45. {0.231557, 0.398671, 0.687661, 0.820069, 0.379319, 0.542037, 0.865300, 0.967672,
  46. 0.643793, 0.824554, 1.124411, 1.191774, 0.760804, 0.909100, 1.181837, 1.223243,}
  47. };
  48. # if B_SZ==4
  49. /* Less extreme weighting -- sqrt of the weights above */
  50. const float satd_weights2[3][B_SZ*B_SZ] =
  51. {{0.230317, 0.329547, 0.457088, 0.517193, 0.315611, 0.389662, 0.525445, 0.577099,
  52. 0.424841, 0.506366, 0.648566, 0.696878, 0.473835, 0.543352, 0.681457, 0.720876,},
  53. {0.458235, 0.617704, 0.811118, 0.890949, 0.601811, 0.711183, 0.906126, 0.962359,
  54. 0.776575, 0.882826, 1.038644, 1.073067, 0.851901, 0.929541, 1.065617, 1.084039,},
  55. {0.481204, 0.631404, 0.829253, 0.905577, 0.615889, 0.736232, 0.930215, 0.983703,
  56. 0.802367, 0.908049, 1.060383, 1.091684, 0.872241, 0.953467, 1.087123, 1.106003,}
  57. };
  58. # else
  59. # error "No weights for B_SZ!=4 yet. #define INTRA_NO_RDO and try again."
  60. # endif
  61. #endif
  /*Mode-coding context derived from the three already-coded neighbors
     (left, up-left, up) of the block at index pos in a map that is width
     blocks wide.*/
  #if 1
  /*One bit per neighbor: set when that neighbor uses mode m.*/
  # define NB_CONTEXTS 8
  # define GET_CONTEXT(modes,pos,m,width) \
   (((modes)[(pos)-1]==(m))*4 \
   +((modes)[(pos)-(width)-1]==(m))*2 \
   +((modes)[(pos)-(width)]==(m))*1)
  #else
  /*Alternative: the left and up-left modes themselves select the context.*/
  # define NB_CONTEXTS 1000
  # define GET_CONTEXT(modes,pos,m,width) \
   (((modes)[(pos)-1])*100 \
   +((modes)[(pos)-(width)-1])*10 \
   +((modes)[(pos)-(width)]==(m))*1)
  #endif
  69. typedef struct intra_xform_ctx intra_xform_ctx;
  70. struct intra_xform_ctx{
  71. char *map_filename;
  72. unsigned char *map;
  73. char *weights_filename;
  74. unsigned *weights;
  75. int nxblocks;
  76. int nyblocks;
  77. int pli;
  78. double r_w[OD_INTRA_NMODES];
  79. double r_x[OD_INTRA_NMODES][2*B_SZ*2*B_SZ];
  80. double r_xx[OD_INTRA_NMODES][2*B_SZ*2*B_SZ][2*B_SZ*2*B_SZ];
  81. double scale[OD_INTRA_NMODES][2*B_SZ*2*B_SZ];
  82. double beta[OD_INTRA_NMODES][B_SZ*B_SZ][2*B_SZ*2*B_SZ];
  83. double freq[3][OD_INTRA_NMODES][NB_CONTEXTS][2];
  84. double p0[OD_INTRA_NMODES];
  85. long long n;
  86. double satd_avg;
  87. double total_bits;
  88. double total_satd;
  89. double total_count;
  90. };
  91. #define PRINT_BLOCKS (0)
  92. static int intra_xform_train_plane_start(void *_ctx,const char *_name,
  93. const video_input_info *_info,int _pli,int _nxblocks,int _nyblocks){
  94. intra_xform_ctx *ctx;
  95. FILE *map_file;
  96. char *map_filename;
  97. FILE *weights_file;
  98. char *weights_filename;
  99. (void)_info;
  100. ctx=(intra_xform_ctx *)_ctx;
  101. ctx->map=(unsigned char *)malloc(_nxblocks*(size_t)_nyblocks);
  102. map_filename=get_map_filename(_name,_pli,_nxblocks,_nyblocks);
  103. map_file=fopen(map_filename,"rb");
  104. if(map_file==NULL){
  105. fprintf(stderr,"Error opening input file '%s'.\n",map_filename);
  106. return EXIT_FAILURE;
  107. }
  108. ctx->map_filename=map_filename;
  109. if(fread(ctx->map,_nxblocks*(size_t)_nyblocks,1,map_file)<1){
  110. fprintf(stderr,"Error reading from input file '%s'.\n",map_filename);
  111. return EXIT_FAILURE;
  112. }
  113. fclose(map_file);
  114. ctx->weights=(unsigned *)malloc(
  115. _nxblocks*(size_t)_nyblocks*sizeof(*ctx->weights));
  116. weights_filename=get_weights_filename(_name,_pli,_nxblocks,_nyblocks);
  117. weights_file=fopen(weights_filename,"rb");
  118. if(weights_file==NULL){
  119. fprintf(stderr,"Error opening input file '%s'.\n",weights_filename);
  120. return EXIT_FAILURE;
  121. }
  122. ctx->weights_filename=weights_filename;
  123. if(fread(ctx->weights,
  124. _nxblocks*(size_t)_nyblocks*sizeof(*ctx->weights),1,weights_file)<1){
  125. fprintf(stderr,"Error reading from input file '%s'.\n",weights_filename);
  126. return EXIT_FAILURE;
  127. }
  128. fclose(weights_file);
  129. #if PRINT_BLOCKS
  130. fprintf(stderr,"%i %i\n",_nxblocks,_nyblocks);
  131. #endif
  132. ctx->nxblocks=_nxblocks;
  133. ctx->nyblocks=_nyblocks;
  134. ctx->pli=_pli;
  135. return EXIT_SUCCESS;
  136. }
  137. #define APPLY_PREFILTER (1)
  138. #define APPLY_DCT (1)
  139. static od_coeff *xform_blocks(od_coeff _buf[3*B_SZ*3*B_SZ],
  140. const unsigned char *_data,int _stride){
  141. od_coeff *buf2;
  142. od_coeff col[B_SZ];
  143. od_coeff *row;
  144. const unsigned char *origin;
  145. int bx;
  146. int by;
  147. int x;
  148. int j;
  149. int i;
  150. origin=_data-(3*B_SZ>>1)*_stride-(3*B_SZ>>1);
  151. for(by=0;by<3;by++){
  152. for(bx=0;bx<3;bx++){
  153. for(j=0;j<B_SZ;j++){
  154. x=B_SZ*bx+j;
  155. #if APPLY_PREFILTER
  156. for(i=0;i<B_SZ;i++)col[i]=origin[_stride*(B_SZ*by+i)+x]-128;
  157. # if B_SZ_LOG>=OD_LOG_BSIZE0&&B_SZ_LOG<OD_LOG_BSIZE0+OD_NBSIZES
  158. (*OD_PRE_FILTER[B_SZ_LOG-OD_LOG_BSIZE0])(col,col);
  159. # else
  160. # error "Need a prefilter implementation for this block size."
  161. # endif
  162. for(i=0;i<B_SZ;i++)_buf[3*B_SZ*(B_SZ*by+i)+x]=col[i];
  163. #else
  164. for(i=0;i<B_SZ;i++){
  165. _buf[3*B_SZ*(B_SZ*by+i)+x]=origin[_stride*(B_SZ*by+i)+x]-128;
  166. }
  167. #endif
  168. }
  169. }
  170. }
  171. #if APPLY_PREFILTER
  172. for(by=0;by<3;by++){
  173. for(bx=0;bx<3;bx++){
  174. for(i=0;i<B_SZ;i++){
  175. row=_buf+3*B_SZ*(B_SZ*by+i)+B_SZ*bx;
  176. (*OD_PRE_FILTER[B_SZ_LOG-OD_LOG_BSIZE0])(row,row);
  177. }
  178. }
  179. }
  180. #endif
  181. buf2=_buf+3*B_SZ*(B_SZ>>1)+(B_SZ>>1);
  182. #if APPLY_DCT
  183. for(by=0;by<2;by++){
  184. for(bx=0;bx<2;bx++){
  185. # if B_SZ_LOG>=OD_LOG_BSIZE0&&B_SZ_LOG<OD_LOG_BSIZE0+OD_NBSIZES
  186. (*OD_FDCT_2D_C[B_SZ_LOG-OD_LOG_BSIZE0])(buf2+B_SZ*(3*B_SZ*by+bx),3*B_SZ,
  187. buf2+B_SZ*(3*B_SZ*by+bx),3*B_SZ);
  188. # else
  189. # error "Need an fDCT implementation for this block size."
  190. # endif
  191. }
  192. }
  193. #endif
  194. return buf2;
  195. }
  196. static void intra_xform_train_block(void *_ctx,const unsigned char *_data,
  197. int _stride,int _bi,int _bj){
  198. intra_xform_ctx *ctx;
  199. double delta[2*B_SZ*2*B_SZ];
  200. double dw;
  201. od_coeff buf[3*B_SZ*3*B_SZ];
  202. od_coeff *buf2;
  203. int mode;
  204. double w;
  205. unsigned wb;
  206. int j;
  207. int i;
  208. int k;
  209. int l;
  210. ctx=(intra_xform_ctx *)_ctx;
  211. buf2=xform_blocks(buf,_data,_stride);
  212. mode=ctx->map[_bj*ctx->nxblocks+_bi];
  213. #if PRINT_BLOCKS
  214. fprintf(stderr,"%i",mode);
  215. for (i=0;i<2*B_SZ;i++) {
  216. for (j=0;j<2*B_SZ;j++) {
  217. fprintf(stderr," %i",buf2[3*B_SZ*i+j]);
  218. }
  219. }
  220. fprintf(stderr,"\n");
  221. #endif
  222. wb=ctx->weights[_bj*ctx->nxblocks+_bi];
  223. if(wb<=0)return;
  224. w=ctx->r_w[mode];
  225. ctx->r_w[mode]+=wb;
  226. dw=wb/(w+wb);
  227. for(i=0;i<2*B_SZ;i++){
  228. for(j=0;j<2*B_SZ;j++){
  229. int ci;
  230. ci=2*B_SZ*i+j;
  231. delta[ci]=(buf2[3*B_SZ*i+j]-ctx->r_x[mode][ci]);
  232. ctx->r_x[mode][ci]+=delta[ci]*dw;
  233. }
  234. }
  235. for(i=0;i<2*B_SZ;i++){
  236. for(j=0;j<2*B_SZ;j++){
  237. int ci;
  238. ci=2*B_SZ*i+j;
  239. for(k=0;k<2*B_SZ;k++){
  240. for(l=0;l<2*B_SZ;l++){
  241. int cj;
  242. cj=2*B_SZ*k+l;
  243. ctx->r_xx[mode][ci][cj]+=w*dw*delta[ci]*delta[cj];
  244. }
  245. }
  246. }
  247. }
  248. if (_bi>0 && _bj>0)
  249. {
  250. unsigned char *modes;
  251. int pos;
  252. int m;
  253. int width;
  254. modes=ctx->map;
  255. pos = _bj*ctx->nxblocks+_bi;
  256. width=ctx->nxblocks;
  257. for(m=0;m<OD_INTRA_NMODES;m++)
  258. {
  259. int c;
  260. c = GET_CONTEXT(modes,pos,m,width);
  261. ctx->freq[ctx->pli][m][c][0]+=1;
  262. ctx->freq[ctx->pli][m][c][1] += (mode==m);
  263. }
  264. }
  265. }
  266. static int intra_xform_train_plane_finish(void *_ctx){
  267. intra_xform_ctx *ctx;
  268. ctx=(intra_xform_ctx *)_ctx;
  269. free(ctx->weights_filename);
  270. free(ctx->weights);
  271. free(ctx->map_filename);
  272. free(ctx->map);
  273. return EXIT_SUCCESS;
  274. }
  275. static const char *MODE_NAME[OD_INTRA_NMODES]={
  276. "OD_INTRA_DC","OD_INTRA_TM","OD_INTRA_HU","OD_INTRA_HE","OD_INTRA_HD",
  277. "OD_INTRA_RD","OD_INTRA_VR","OD_INTRA_VE","OD_INTRA_VL","OD_INTRA_LD"
  278. };
  279. static void print_beta(int _mode,int _i,int _j,double *_beta){
  280. int i;
  281. int j;
  282. printf(" /*%s (%i,%i)*/\n",MODE_NAME[_mode],_i,_j);
  283. printf(" {\n");
  284. for(j=0;j<2*B_SZ;j++){
  285. if(j==B_SZ)printf("\n");
  286. printf(" {");
  287. /*printf(" {\n");
  288. printf(" ");*/
  289. for(i=0;i<2*B_SZ;i++){
  290. printf("%s%- 24.18G%s",i==B_SZ?" ":"",_beta[2*B_SZ*j+i],i<2*B_SZ-1?",":"");
  291. }
  292. /*printf(" }%s\n",j<2*B_SZ-1?",":"");*/
  293. printf("}%s\n",j<2*B_SZ-1?",":"");
  294. }
  295. printf(" }%s\n",_j<B_SZ-1?",":"");
  296. }
  297. typedef double r_xx_row[2*B_SZ*2*B_SZ];
  298. static void update_intra_xforms(intra_xform_ctx *_ctx){
  299. int mode;
  300. int pli;
  301. /*Update the model for each coefficient in each mode.*/
  302. printf("/* This file is generated automatically by init_intra_xform */\n");
  303. printf("#include \"intra.h\"\n");
  304. printf("\n");
  305. printf("const double OD_INTRA_PRED_WEIGHTS_%ix%i"
  306. "[OD_INTRA_NMODES][%i][%i][2*%i][2*%i]={\n",
  307. B_SZ,B_SZ,B_SZ,B_SZ,B_SZ,B_SZ);
  308. for(mode=0;mode<OD_INTRA_NMODES;mode++){
  309. int xi[2*B_SZ*2*B_SZ];
  310. int nxi;
  311. int i;
  312. int j;
  313. /*double *r_x;*/
  314. r_xx_row *r_xx;
  315. double *scale;
  316. /*r_x=_ctx->r_x[mode];*/
  317. r_xx=_ctx->r_xx[mode];
  318. scale=_ctx->scale[mode];
  319. printf(" {\n");
  320. for(i=0;i<2*B_SZ*2*B_SZ;i++){
  321. scale[i]=sqrt(r_xx[i][i]);
  322. if(scale[i]<=0)scale[i]=1;
  323. }
  324. for(i=0;i<2*B_SZ*2*B_SZ;i++){
  325. for(j=0;j<2*B_SZ*2*B_SZ;j++){
  326. r_xx[i][j]/=scale[i]*scale[j];
  327. }
  328. }
  329. nxi=0;
  330. for(j=0;j<B_SZ;j++){
  331. for(i=0;i<B_SZ;i++){
  332. xi[nxi]=2*B_SZ*j+i;
  333. xi[nxi+B_SZ*B_SZ]=2*B_SZ*j+B_SZ+i;
  334. xi[nxi+2*B_SZ*B_SZ]=2*B_SZ*(B_SZ+j)+i;
  335. nxi++;
  336. }
  337. }
  338. #if 0
  339. if(mode==0){
  340. for(i=0;i<2*B_SZ;i++){
  341. for(j=0;j<2*B_SZ;j++){
  342. int k;
  343. int l;
  344. for(k=0;k<2*B_SZ;k++){
  345. for(l=0;l<2*B_SZ;l++){
  346. printf("%0.18G%s",r_xx[2*B_SZ*i+j][2*B_SZ*k+l],2*B_SZ*k+l>=2*B_SZ*2*B_SZ-1?"\n":" ");
  347. }
  348. }
  349. }
  350. }
  351. }
  352. #endif
  353. for(i=0;i<B_SZ;i++){
  354. printf(" {\n");
  355. for(j=0;j<B_SZ;j++){
  356. double xty[2*B_SZ*2*B_SZ];
  357. double *beta;
  358. int xii;
  359. int xij;
  360. int yi;
  361. nxi=3*B_SZ*B_SZ;
  362. #if 0
  363. /*Include coefficients for the current block*/
  364. {
  365. int k;
  366. int l;
  367. for(k=0;k<=i;k++){
  368. for(l=0;l<=j;l++){
  369. xi[nxi++]=2*B_SZ*(B_SZ+k)+B_SZ+l;
  370. }
  371. }
  372. nxi--;
  373. }
  374. #endif
  375. yi=2*B_SZ*(B_SZ+i)+B_SZ+j;
  376. for(xii=0;xii<nxi;xii++)xty[xii]=r_xx[xi[xii]][yi];
  377. beta=_ctx->beta[mode][B_SZ*i+j];
  378. memset(beta,0,2*B_SZ*2*B_SZ*sizeof(*beta));
  379. #if defined(OD_USE_SVD)
  380. {
  381. double xtx[2*2*B_SZ*2*B_SZ][2*B_SZ*2*B_SZ];
  382. double *xtxp[2*2*B_SZ*2*B_SZ];
  383. double s[2*B_SZ*2*B_SZ];
  384. for(xii=0;xii<nxi;xii++){
  385. for(xij=0;xij<nxi;xij++){
  386. xtx[xii][xij]=r_xx[xi[xii]][xi[xij]];
  387. }
  388. }
  389. for(xii=0;xii<2*nxi;xii++)xtxp[xii]=xtx[xii];
  390. svd_pseudoinverse(xtxp,s,nxi,nxi);
  391. /*beta[yi]=r_x[yi];*/
  392. for(xii=0;xii<nxi;xii++){
  393. double beta_i;
  394. beta_i=0;
  395. for(xij=0;xij<nxi;xij++)beta_i+=xtx[xij][xii]*xty[xij];
  396. beta[xi[xii]]=beta_i*scale[yi]/scale[xi[xii]];
  397. /*beta[yi]-=beta_i*r_x[xi[xii]];*/
  398. }
  399. }
  400. #else
  401. {
  402. double xtx[UT_SZ(2*B_SZ*2*B_SZ,2*B_SZ*2*B_SZ)];
  403. double tau[2*B_SZ*2*B_SZ];
  404. double work[2*B_SZ*2*B_SZ];
  405. int pivot[2*B_SZ*2*B_SZ];
  406. int rank;
  407. for(xii=0;xii<nxi;xii++){
  408. for(xij=xii;xij<nxi;xij++){
  409. xtx[UT_IDX(xii,xij,nxi)]=r_xx[xi[xii]][xi[xij]];
  410. }
  411. }
  412. rank=cholesky(xtx,pivot,DBL_EPSILON,nxi);
  413. chdecomp(xtx,tau,rank,nxi);
  414. chsolve(xtx,pivot,tau,xty,xty,work,rank,nxi);
  415. for(xii=0;xii<nxi;xii++){
  416. beta[xi[xii]]=xty[xii]*scale[yi]/scale[xi[xii]];
  417. /*beta[yi]-=beta_i*r_x[xi[xii]];*/
  418. }
  419. }
  420. #endif
  421. print_beta(mode,i,j,beta);
  422. }
  423. printf(" }%s\n",i<B_SZ-1?",":"");
  424. }
  425. printf(" }%s\n",mode<OD_INTRA_NMODES-1?",":"");
  426. }
  427. printf("};\n\n");
  428. printf("const unsigned char OD_INTRA_PRED_PROB_%dx%d[3][OD_INTRA_NMODES][OD_INTRA_NCONTEXTS]={\n",B_SZ,B_SZ);
  429. for(pli=0;pli<3;pli++)
  430. {
  431. int i;
  432. printf("{");
  433. for(i=0;i<OD_INTRA_NMODES;i++)
  434. {
  435. int j;
  436. printf("{");
  437. for(j=0;j<NB_CONTEXTS;j++)
  438. printf("%d, ", (int)floor(.5+256.*_ctx->freq[pli][i][j][1]/(float)_ctx->freq[pli][i][j][0]));
  439. printf("},\n");
  440. }
  441. printf("},\n");
  442. }
  443. printf("};\n\n");
  444. }
  445. static int intra_xform_update_plane_start(void *_ctx,const char *_name,
  446. const video_input_info *_info,int _pli,int _nxblocks,int _nyblocks){
  447. intra_xform_ctx *ctx;
  448. int i;
  449. (void)_info;
  450. ctx=(intra_xform_ctx *)_ctx;
  451. ctx->map_filename=get_map_filename(_name,_pli,_nxblocks,_nyblocks);
  452. ctx->weights_filename=get_weights_filename(_name,_pli,_nxblocks,_nyblocks);
  453. ctx->map=(unsigned char *)malloc(_nxblocks*(size_t)_nyblocks);
  454. ctx->weights=(unsigned *)malloc(
  455. _nxblocks*(size_t)_nyblocks*sizeof(*ctx->weights));
  456. ctx->nxblocks=_nxblocks;
  457. ctx->nyblocks=_nyblocks;
  458. ctx->pli=_pli;
  459. for(i=0;i<OD_INTRA_NMODES;i++)
  460. ctx->p0[i]=ctx->freq[ctx->pli][i][0][1]/(float)ctx->freq[ctx->pli][i][0][0];
  461. return EXIT_SUCCESS;
  462. }
  463. static void intra_xform_update_block(void *_ctx,const unsigned char *_data,
  464. int _stride,int _bi,int _bj){
  465. intra_xform_ctx *ctx;
  466. od_coeff buf[3*B_SZ*3*B_SZ];
  467. od_coeff *buf2;
  468. double best_satd;
  469. double best_rlsatd;
  470. double best_bits;
  471. double next_best_satd;
  472. double next_best_rlsatd;
  473. double best_error[B_SZ*B_SZ]={0};
  474. double error[B_SZ*B_SZ];
  475. int mode;
  476. int best_mode;
  477. int c0[OD_INTRA_NMODES]={0};
  478. double bits=0;
  479. int i;
  480. float sum=0;
  481. float sum2=0;
  482. unsigned char *modes;
  483. int pos;
  484. int m;
  485. int width;
  486. float p[OD_INTRA_NMODES];
  487. ogg_uint16_t cdf[OD_INTRA_NMODES];
  488. /*If using this be sure to uncomment its assignment.*/
  489. /* ogg_uint32_t wsatd[OD_INTRA_NMODES];*/
  490. ctx=(intra_xform_ctx *)_ctx;
  491. modes=ctx->map;
  492. pos = _bj*ctx->nxblocks+_bi;
  493. width=ctx->nxblocks;
  494. buf2=xform_blocks(buf,_data,_stride);
  495. best_mode=0;
  496. best_satd=UINT_MAX;
  497. best_rlsatd=UINT_MAX;
  498. best_bits=0;
  499. next_best_satd=UINT_MAX;
  500. next_best_rlsatd=UINT_MAX;
  501. {
  502. int c;
  503. unsigned char probs[OD_INTRA_NMODES][OD_INTRA_NCONTEXTS];
  504. int left;
  505. int upleft;
  506. int up;
  507. left=(_bi==0)?0:modes[pos-1];
  508. up=(_bj==0)?0:modes[pos-width];
  509. upleft=(_bi==0||_bj==0)?0:modes[pos-width-1];
  510. for (m=0;m<OD_INTRA_NMODES;m++)
  511. for(c=0;c<OD_INTRA_NCONTEXTS;c++)
  512. probs[m][c] = 256.*ctx->freq[ctx->pli][m][c][1]/(float)ctx->freq[ctx->pli][m][c][0];
  513. od_intra_pred_cdf(cdf,probs,OD_INTRA_NMODES,left,upleft,up);
  514. }
  515. for(m=0;m<OD_INTRA_NMODES;m++)
  516. {
  517. int c;
  518. if (_bi>0 && _bj>0)
  519. c=GET_CONTEXT(modes,pos,m,width);
  520. else if (_bj>0)
  521. #if 1
  522. c=(modes[pos-width]==m)+6*(m==0);
  523. #else
  524. c=(modes[pos-width]);
  525. #endif
  526. else if (_bi>0)
  527. #if 1
  528. c=(modes[pos-1]==m)*4+3*(m==0);
  529. #else
  530. c=(modes[pos-1])*100;
  531. #endif
  532. else
  533. c=15*(m==0);
  534. p[m] = ctx->freq[ctx->pli][m][c][1]/(float)ctx->freq[ctx->pli][m][c][0];
  535. p[m]+=1.e-5;
  536. if (p[m]<ctx->p0[m]) p[m]=ctx->p0[m];
  537. if (c==0)
  538. {
  539. ctx->p0[m]*=(1-P0_COEF);
  540. c0[m]=1;
  541. }
  542. sum += p[m];
  543. }
  544. for(m=0;m<OD_INTRA_NMODES;m++)
  545. {
  546. p[m] *= (sum-p[m])/(1-p[m]);
  547. sum2+=p[m];
  548. }
  549. for(mode=0;mode<OD_INTRA_NMODES;mode++){
  550. double satd;
  551. double rlsatd;
  552. double diff;
  553. int j;
  554. satd=0;
  555. for(i=0;i<B_SZ;i++){
  556. for(j=0;j<B_SZ;j++){
  557. const double *beta;
  558. double p;
  559. int k;
  560. int l;
  561. beta=ctx->beta[mode][B_SZ*i+j];
  562. p=0;
  563. for(k=0;k<2*B_SZ;k++){
  564. for(l=0;l<2*B_SZ;l++){
  565. p+=beta[2*B_SZ*k+l]*buf2[3*B_SZ*k+l];
  566. }
  567. }
  568. #ifdef INTRA_NO_RDO
  569. diff = fabs(buf2[3*B_SZ*(i+B_SZ)+j+B_SZ]-(od_coeff)floor(p+0.5));
  570. satd+=diff;
  571. #else
  572. /* Simulates quantization with dead zone (without the annoying quantization effects) */
  573. diff = fabs(buf2[3*B_SZ*(i+B_SZ)+j+B_SZ]-p) - 1;
  574. if (diff<0)diff=0;
  575. /*satd+=satd_weights2[i*B_SZ+j]*diff;*/
  576. satd+=satd_weights2[ctx->pli][i*B_SZ+j]*diff;
  577. #endif
  578. error[i*B_SZ+j]=buf2[3*B_SZ*(i+B_SZ)+j+B_SZ]-p;
  579. }
  580. }
  581. /* wsatd[mode] = satd*64;*/
  582. rlsatd=satd;
  583. /* Normalize all probabilities except the max */
  584. /*bits = -log(p[mode]/sum)/log(2);*/
  585. /* Normalize all probabilities except the max */
  586. /*bits = (mode==maxM) ? -log(p[mode])/log(2) : -log(p[mode]*(1-maxP)/(sum-maxP))/log(2);*/
  587. bits = -log(p[mode]/sum2)/log(2);
  588. /*printf("{%f+l*%f= ", satd , bits);*/
  589. #ifndef INTRA_NO_RDO
  590. satd += 1.1*bits;
  591. #endif
  592. /*printf("%f} ", satd);*/
  593. /* Bias towards DC mode */
  594. /*if (mode==0)satd-=.5;*/
  595. if(satd<best_satd){
  596. next_best_satd=best_satd;
  597. next_best_rlsatd=best_rlsatd;
  598. best_satd=satd;
  599. best_rlsatd=rlsatd;
  600. best_mode=mode;
  601. best_bits=bits;
  602. for (i=0;i<B_SZ*B_SZ;i++)
  603. best_error[i]=error[i];
  604. }
  605. else if(satd<next_best_satd){next_best_satd=satd;next_best_rlsatd=rlsatd;}
  606. }
  607. /*printf("\n");*/
  608. #if 0
  609. {
  610. int bmode;
  611. int left, up, upleft;
  612. left=(_bi==0)?0:modes[pos-1];
  613. up=(_bj==0)?0:modes[pos-width];
  614. upleft=(_bi==0||_bj==0)?0:modes[pos-width-1];
  615. bmode=od_intra_pred_search(cdf,wsatd,OD_INTRA_NMODES,64*1.1);
  616. od_intra_pred_update(p0,OD_INTRA_NMODES,bmode,left,upleft,up);
  617. /*if (bmode==best_mode)
  618. printf("+");
  619. else
  620. printf("-");
  621. printf("%d %d\n", bmode, best_mode);*/
  622. }
  623. #endif
  624. for (i=0;i<B_SZ*B_SZ;i++)
  625. Ex[ctx->pli][i]+=fabs(best_error[i]);
  626. ExCount[ctx->pli]++;
  627. if (c0[best_mode])
  628. ctx->p0[best_mode] += P0_COEF;
  629. /*fprintf(stderr,"%f\n", best_bits);*/
  630. ctx->total_bits += best_bits;
  631. ctx->total_satd+=best_satd;
  632. ctx->total_count += 1;
  633. ctx->satd_avg+=(best_satd-ctx->satd_avg)/++(ctx->n);
  634. ctx->map[_bj*ctx->nxblocks+_bi]=best_mode;
  635. ctx->weights[_bj*ctx->nxblocks+_bi]=floor((next_best_rlsatd-best_rlsatd)*1000.);
  636. if(next_best_rlsatd<=best_rlsatd)ctx->weights[_bj*ctx->nxblocks+_bi]=0;
  637. if(best_mode==0)ctx->weights[_bj*ctx->nxblocks+_bi]=1;
  638. }
  639. static int intra_xform_update_plane_finish(void *_ctx){
  640. intra_xform_ctx *ctx;
  641. FILE *map_file;
  642. FILE *weights_file;
  643. ctx=(intra_xform_ctx *)_ctx;
  644. map_file=fopen(ctx->map_filename,"wb");
  645. if(map_file==NULL){
  646. fprintf(stderr,"Error opening output file '%s'.\n",ctx->map_filename);
  647. return EXIT_FAILURE;
  648. }
  649. if(fwrite(ctx->map,ctx->nxblocks*(size_t)ctx->nyblocks,1,map_file)<1){
  650. fprintf(stderr,"Error writing to output file '%s'.\n",ctx->map_filename);
  651. return EXIT_FAILURE;
  652. }
  653. fclose(map_file);
  654. weights_file=fopen(ctx->weights_filename,"wb");
  655. if(weights_file==NULL){
  656. fprintf(stderr,"Error opening output file '%s'.\n",ctx->weights_filename);
  657. return EXIT_FAILURE;
  658. }
  659. if(fwrite(ctx->weights,
  660. ctx->nxblocks*(size_t)ctx->nyblocks*sizeof(*ctx->weights),1,
  661. weights_file)<1){
  662. fprintf(stderr,"Error writing to output file '%s'.\n",
  663. ctx->weights_filename);
  664. return EXIT_FAILURE;
  665. }
  666. fclose(weights_file);
  667. /*printf("Average SATD: %G\n",ctx->satd_avg);*/
  668. return intra_xform_train_plane_finish(_ctx);
  669. }
  670. int main(int _argc,const char **_argv){
  671. static intra_xform_ctx ctx;
  672. int ret;
  673. int i;
  674. int pli;
  675. for(pli=0;pli<3;pli++)
  676. {
  677. for(i=0;i<OD_INTRA_NMODES;i++)
  678. {
  679. int j;
  680. for(j=0;j<NB_CONTEXTS;j++)
  681. {
  682. ctx.freq[pli][i][j][0]=2;
  683. ctx.freq[pli][i][j][1]=1;
  684. }
  685. }
  686. }
  687. ctx.total_bits=0;
  688. ctx.total_count=0;
  689. ctx.total_satd=0;
  690. ret=apply_to_blocks(&ctx,intra_xform_train_plane_start,
  691. intra_xform_train_block,intra_xform_train_plane_finish,_argc,_argv);
  692. if(ret==EXIT_SUCCESS){
  693. update_intra_xforms(&ctx);
  694. ret=apply_to_blocks(&ctx,intra_xform_update_plane_start,
  695. intra_xform_update_block,intra_xform_update_plane_finish,_argc,_argv);
  696. }
  697. #if 0
  698. for (pli=0;pli<3;pli++)
  699. {
  700. printf("Ex: ");
  701. for(i=0;i<B_SZ*B_SZ;i++)
  702. {
  703. printf("%f ", Ex[pli][i]/ExCount[pli]);
  704. }
  705. printf("\n");
  706. }
  707. #endif
  708. fprintf(stderr, "Average cost: %f bits/block, satd+cost: %f\n", ctx.total_bits/ctx.total_count, ctx.total_satd/ctx.total_count);
  709. return ret;
  710. }