/* intra_pred.c */
#include <float.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/timeb.h>
#include "cholesky.h"
#include "od_defs.h"
#include "od_covmat.h"
#include "od_filter.h"
#include "od_intra.h"
#include "image_tools.h"
#include "stats_tools.h"
#include "svd.h"
/* Compile-time switches; each one is tested with #if further down. */
/* Non-zero: solve() uses svd_pseudoinverse() instead of the
    cholesky()/chdecomp()/chsolve() path. */
#define USE_SVD (0)
/* Non-zero: od_mode_block() uses od_select_mode_bits() for every step after
    the first; otherwise it always uses od_select_mode_satd(). */
#define BITS_SELECT (0)
/* Non-zero: the weight returned by the mode selector is kept unless the
    selected mode is 0; zero: every block gets weight 1. */
#define USE_WEIGHTS (0)
/* Non-zero: ip_add_block() subtracts 0.25 times the first coefficient of each
    of the four neighbor blocks from the first coefficient of the fifth block,
    and main() adds 0.25 back into the matching beta_1 entries. */
#define SUBTRACT_DC (0)
/* Non-zero: the residual covariance is accumulated over all modes before
    update_diversity() is run; zero: it is corrected and reset per mode. */
#define POOLED_COV (1)
/* Non-zero: adds the image_files member and the ip_files_block() stage. */
#define WRITE_IMAGES (0)
/* Non-zero: each stage reports its thread via print_progress() for the block
    at (0,0). */
#define PRINT_PROGRESS (0)
/* Non-zero: adds the ip_print_block() stage to the INIT pipeline. */
#define PRINT_BLOCKS (0)
/* Non-zero: dumps xtx, xty, beta_0, beta_1 and ete to stderr. */
#define PRINT_COMP (0)
/* Non-zero: prints every dropped coefficient and its cost. */
#define PRINT_DROPS (0)
/* Non-zero: main() calls print_predictors(stderr) once training is done. */
#define PRINT_BETAS (0)
/* Per-worker state; main() keeps an array of NUM_PROCS of these and hands it
    to ne_apply_to_blocks(). */
typedef struct classify_ctx classify_ctx;
struct classify_ctx{
  /* Incremented by every classify_ctx_set_image() call. */
  int n;
  /* Statistics for the current image; reset in classify_ctx_set_image(). */
  intra_stats st;
  /* Statistics combined over images; reset in classify_ctx_reset(). */
  intra_stats gb;
  /* One covariance accumulator of dimension 5*B_SZ*B_SZ per mode. */
  od_covmat pd[OD_INTRA_NMODES];
  /* Data for the image currently being processed. */
  image_data img;
#if WRITE_IMAGES
  /* Output images; only present when WRITE_IMAGES is enabled. */
  image_files files;
#endif
  /* Running bit estimate accumulated by ip_stats_block(). */
  double bits;
};
  36. static void classify_ctx_init(classify_ctx *_this){
  37. int i;
  38. _this->n=0;
  39. intra_stats_init(&_this->st);
  40. intra_stats_init(&_this->gb);
  41. for(i=0;i<OD_INTRA_NMODES;i++){
  42. od_covmat_init(&_this->pd[i],5*B_SZ*B_SZ);
  43. }
  44. }
  45. static void classify_ctx_clear(classify_ctx *_this){
  46. int i;
  47. intra_stats_clear(&_this->st);
  48. intra_stats_clear(&_this->gb);
  49. for(i=0;i<OD_INTRA_NMODES;i++){
  50. od_covmat_clear(&_this->pd[i]);
  51. }
  52. }
  53. static void classify_ctx_reset(classify_ctx *_this){
  54. int i;
  55. _this->n=0;
  56. intra_stats_reset(&_this->gb);
  57. for(i=0;i<OD_INTRA_NMODES;i++){
  58. od_covmat_reset(&_this->pd[i]);
  59. }
  60. _this->bits=0;
  61. }
  62. static void classify_ctx_set_image(classify_ctx *_this,const char *_name,
  63. int _nxblocks,int _nyblocks){
  64. _this->n++;
  65. intra_stats_reset(&_this->st);
  66. image_data_init(&_this->img,_name,_nxblocks,_nyblocks);
  67. #if WRITE_IMAGES
  68. image_files_init(&_this->files,_nxblocks,_nyblocks);
  69. #endif
  70. }
  71. static void classify_ctx_clear_image(classify_ctx *_this){
  72. image_data_clear(&_this->img);
  73. #if WRITE_IMAGES
  74. image_files_clear(&_this->files);
  75. #endif
  76. }
/* The regression problem for one mode, built from an od_covmat by
    prob_ctx_load(). */
typedef struct prob_ctx prob_ctx;
struct prob_ctx{
  /* 5*B_SZ*B_SZ square roots of the covariance diagonal. */
  double *scale;
  /* (4*B_SZ*B_SZ)^2 normalized X^T*X entries. */
  double *xtx;
  /* (4*B_SZ*B_SZ) by (B_SZ*B_SZ) normalized X^T*Y entries. */
  double *xty;
  /* Borrowed from the od_covmat passed to prob_ctx_load(); not freed here. */
  const double *mean;
  /* Borrowed from the od_covmat passed to prob_ctx_load(); not freed here. */
  const double *cov;
  /* (B_SZ*B_SZ)^2 residual matrix filled by prob_ctx_comp_error(). */
  double *ete;
};
  86. static void prob_ctx_init(prob_ctx *_this){
  87. _this->scale=(double *)malloc(sizeof(*_this->scale)*5*B_SZ*B_SZ);
  88. _this->xtx=(double *)malloc(sizeof(*_this->xtx)*4*B_SZ*B_SZ*4*B_SZ*B_SZ);
  89. _this->xty=(double *)malloc(sizeof(*_this->xty)*4*B_SZ*B_SZ*B_SZ*B_SZ);
  90. _this->mean=NULL;
  91. _this->cov=NULL;
  92. _this->ete=(double *)malloc(sizeof(*_this->ete)*B_SZ*B_SZ*B_SZ*B_SZ);
  93. }
  94. static void prob_ctx_clear(prob_ctx *_this){
  95. free(_this->scale);
  96. free(_this->xtx);
  97. free(_this->xty);
  98. _this->mean=NULL;
  99. _this->cov=NULL;
  100. free(_this->ete);
  101. }
  102. static void prob_ctx_load(prob_ctx *_this,od_covmat *_mat){
  103. int i;
  104. int j;
  105. /* compute the scale factors */
  106. for(i=0;i<5*B_SZ*B_SZ;i++){
  107. _this->scale[i]=sqrt(_mat->cov[i*5*B_SZ*B_SZ+i]);
  108. }
  109. /* normalize X^T*X and X^T*Y */
  110. for(j=0;j<4*B_SZ*B_SZ;j++){
  111. for(i=0;i<4*B_SZ*B_SZ;i++){
  112. _this->xtx[4*B_SZ*B_SZ*j+i]=
  113. _mat->cov[5*B_SZ*B_SZ*j+i]/(_this->scale[j]*_this->scale[i]);
  114. }
  115. for(i=0;i<B_SZ*B_SZ;i++){
  116. _this->xty[B_SZ*B_SZ*j+i]=_mat->cov[5*B_SZ*B_SZ*j+4*B_SZ*B_SZ+i];
  117. _this->xty[B_SZ*B_SZ*j+i]/=_this->scale[j]*_this->scale[4*B_SZ*B_SZ+i];
  118. }
  119. }
  120. _this->mean=_mat->mean;
  121. _this->cov=_mat->cov;
  122. }
  123. static void prob_ctx_comp_error(prob_ctx *_this,od_covmat *_mat,double *_beta_1){
  124. int j;
  125. int i;
  126. for(j=0;j<B_SZ*B_SZ;j++){
  127. for(i=0;i<B_SZ*B_SZ;i++){
  128. int ji;
  129. int l;
  130. int k;
  131. ji=B_SZ*B_SZ*j+i;
  132. l=5*B_SZ*B_SZ*(4*B_SZ*B_SZ+j);
  133. /* E^T*E = Y^T*Y - Y^T*X * beta_1 */
  134. _this->ete[ji]=_mat->cov[l+4*B_SZ*B_SZ+i];
  135. for(k=0;k<4*B_SZ*B_SZ;k++){
  136. _this->ete[ji]-=_mat->cov[l+k]*_beta_1[4*B_SZ*B_SZ*i+k];
  137. }
  138. }
  139. }
  140. #if PRINT_COMP
  141. fprintf(stderr,"ete=[");
  142. for(j=0;j<B_SZ*B_SZ;j++){
  143. fprintf(stderr,"%s",j!=0?";":"");
  144. for(i=0;i<B_SZ*B_SZ;i++){
  145. fprintf(stderr,"%s%- 24.18G",i!=0?",":"",_this->ete[B_SZ*B_SZ*j+i]);
  146. }
  147. }
  148. fprintf(stderr,"];\n");
  149. #endif
  150. }
  151. static void update_diversity(const double *_ete,double _b[B_SZ*B_SZ],
  152. const double *_scale){
  153. int v;
  154. int u;
  155. for(v=0;v<B_SZ;v++){
  156. for(u=0;u<B_SZ;u++){
  157. int i;
  158. int ii;
  159. i=B_SZ*v+u;
  160. ii=B_SZ*B_SZ*i+i;
  161. _b[i]=sqrt(_scale[v]*_scale[u]*_ete[ii]/2);
  162. }
  163. }
  164. }
/* Scratch space and results for solve(); one per worker thread. */
typedef struct solve_ctx solve_ctx;
struct solve_ctx{
#if USE_SVD
  /* 2*(4*B_SZ*B_SZ) rows of 4*B_SZ*B_SZ doubles passed to the SVD. */
  double *xtx;
  /* Row pointers into xtx. */
  double **xtxp;
  /* 4*B_SZ*B_SZ entries passed to svd_pseudoinverse(). */
  double *s;
#else
  /* Packed upper-triangular matrix, UT_SZ(4*B_SZ*B_SZ,4*B_SZ*B_SZ) entries. */
  double *r;
  /* Pivot vector filled by cholesky(). */
  int *pivot;
  /* Passed to chdecomp() and chsolve(). */
  double *tau;
  /* Right-hand side on entry to chsolve(), solution on exit. */
  double *b;
  /* Workspace for chsolve(). */
  double *work;
#endif
  /* B_SZ*B_SZ offsets, one per output coefficient. */
  double *beta_0;
  /* B_SZ*B_SZ rows of 4*B_SZ*B_SZ coefficients. */
  double *beta_1;
};
  181. static void solve_ctx_init(solve_ctx *_this){
  182. #if USE_SVD
  183. _this->xtx=(double *)malloc(sizeof(*_this->xtx)*2*4*B_SZ*B_SZ*4*B_SZ*B_SZ);
  184. _this->xtxp=(double **)malloc(sizeof(*_this->xtxp)*2*4*B_SZ*B_SZ);
  185. _this->s=(double *)malloc(sizeof(*_this->s)*4*B_SZ*B_SZ);
  186. #else
  187. _this->r=(double *)malloc(sizeof(*_this->r)*UT_SZ(4*B_SZ*B_SZ,4*B_SZ*B_SZ));
  188. _this->pivot=(int *)malloc(sizeof(*_this->pivot)*4*B_SZ*B_SZ);
  189. _this->tau=(double *)malloc(sizeof(*_this->tau)*4*B_SZ*B_SZ);
  190. _this->b=(double *)malloc(sizeof(*_this->b)*4*B_SZ*B_SZ);
  191. _this->work=(double *)malloc(sizeof(*_this->work)*4*B_SZ*B_SZ);
  192. #endif
  193. _this->beta_0=(double *)malloc(sizeof(*_this->beta_0)*B_SZ*B_SZ);
  194. _this->beta_1=(double *)malloc(sizeof(*_this->beta_1)*B_SZ*B_SZ*4*B_SZ*B_SZ);
  195. }
  196. static void solve_ctx_clear(solve_ctx *_this){
  197. #if USE_SVD
  198. free(_this->xtx);
  199. free(_this->xtxp);
  200. free(_this->s);
  201. #else
  202. free(_this->r);
  203. free(_this->pivot);
  204. free(_this->tau);
  205. free(_this->b);
  206. free(_this->work);
  207. #endif
  208. free(_this->beta_0);
  209. free(_this->beta_1);
  210. }
  211. /* solve for beta_0[_y] and beta_1[_y] */
  212. static void solve(const prob_ctx *_prob,solve_ctx *_sol,int _y,int *_mask,
  213. double *_beta_0,double *_beta_1){
  214. int nmi;
  215. int mi[4*B_SZ*B_SZ];
  216. int i;
  217. int j;
  218. #if !USE_SVD
  219. int rank;
  220. #endif
  221. nmi=0;
  222. for(i=0;i<4*B_SZ*B_SZ;i++){
  223. if(_mask[i]){
  224. mi[nmi]=i;
  225. nmi++;
  226. }
  227. _beta_1[_y*4*B_SZ*B_SZ+i]=0;
  228. }
  229. #if USE_SVD
  230. for(j=0;j<nmi;j++){
  231. for(i=0;i<nmi;i++){
  232. _sol->xtx[4*B_SZ*B_SZ*j+i]=_prob->xtx[4*B_SZ*B_SZ*mi[j]+mi[i]];
  233. }
  234. }
  235. for(i=0;i<2*nmi;i++){
  236. _sol->xtxp[i]=&_sol->xtx[4*B_SZ*B_SZ*i];
  237. }
  238. svd_pseudoinverse(_sol->xtxp,_sol->s,nmi,nmi);
  239. #else
  240. for(j=0;j<nmi;j++){
  241. for(i=j;i<nmi;i++){
  242. _sol->r[UT_IDX(j,i,nmi)]=_prob->xtx[4*B_SZ*B_SZ*mi[j]+mi[i]];
  243. }
  244. _sol->b[j]=_prob->xty[B_SZ*B_SZ*mi[j]+_y];
  245. }
  246. rank=cholesky(_sol->r,_sol->pivot,DBL_EPSILON,nmi);
  247. chdecomp(_sol->r,_sol->tau,rank,nmi);
  248. chsolve(_sol->r,_sol->pivot,_sol->tau,_sol->b,_sol->b,_sol->work,rank,nmi);
  249. #endif
  250. /* compute beta_1 = (X^T*X)^-1 * X^T*Y and beta_0 = Ym - Xm * beta_1 */
  251. _beta_0[_y]=_prob->mean[4*B_SZ*B_SZ+_y];
  252. for(j=0;j<nmi;j++){
  253. int yj;
  254. yj=_y*4*B_SZ*B_SZ+mi[j];
  255. #if USE_SVD
  256. _beta_1[yj]=0;
  257. for(i=0;i<nmi;i++){
  258. _beta_1[yj]+=_sol->xtx[4*B_SZ*B_SZ*j+i]*_prob->xty[B_SZ*B_SZ*mi[i]+_y];
  259. }
  260. #else
  261. _beta_1[yj]=_sol->b[j];
  262. #endif
  263. _beta_1[yj]*=_prob->scale[4*B_SZ*B_SZ+_y]/_prob->scale[mi[j]];
  264. _beta_0[_y]-=_prob->mean[mi[j]]*_beta_1[yj];
  265. }
  266. }
  267. static double comp_error(const prob_ctx *_prob,solve_ctx *_sol,int _y,
  268. int *_mask){
  269. double err;
  270. int j;
  271. int i;
  272. solve(_prob,_sol,_y,_mask,_sol->beta_0,_sol->beta_1);
  273. j=4*B_SZ*B_SZ+_y;
  274. err=_prob->cov[5*B_SZ*B_SZ*j+j];
  275. for(i=0;i<4*B_SZ*B_SZ;i++){
  276. if(_mask[i]){
  277. err-=_prob->cov[5*B_SZ*B_SZ*j+i]*_sol->beta_1[4*B_SZ*B_SZ*_y+i];
  278. }
  279. }
  280. return(err);
  281. }
  282. static int comp_delta_pg(const prob_ctx *_prob,solve_ctx _sol[NUM_PROCS],int _y,
  283. int _mask[4*B_SZ*B_SZ],double *_delta_pg){
  284. double s;
  285. int i;
  286. int j;
  287. int nmi;
  288. int mi[4*B_SZ*B_SZ];
  289. int mask[NUM_PROCS][4*B_SZ*B_SZ];
  290. double delta_pg[4*B_SZ*B_SZ];
  291. nmi=0;
  292. for(i=0;i<4*B_SZ*B_SZ;i++){
  293. if(_mask[i]){
  294. mi[nmi]=i;
  295. nmi++;
  296. }
  297. for(j=0;j<NUM_PROCS;j++){
  298. mask[j][i]=_mask[i];
  299. }
  300. }
  301. s=1/comp_error(_prob,&_sol[0],_y,_mask);
  302. #pragma omp parallel for schedule(dynamic)
  303. for(i=0;i<nmi;i++){
  304. int tid;
  305. tid=omp_get_thread_num();
  306. mask[tid][mi[i]]=0;
  307. delta_pg[i]=comp_error(_prob,&_sol[tid],_y,mask[tid])*s;
  308. mask[tid][mi[i]]=1;
  309. }
  310. #if SUBTRACT_DC
  311. if(_y==0) {
  312. for(i=0;i<nmi;i++){
  313. if(mi[i]%(B_SZ*B_SZ)==0){
  314. delta_pg[i]=UINT_MAX;
  315. }
  316. }
  317. }
  318. #endif
  319. j=-1;
  320. for(i=0;i<nmi;i++){
  321. if(j==-1||delta_pg[i]<delta_pg[j]){
  322. j=i;
  323. }
  324. }
  325. if(j==-1){
  326. return j;
  327. }
  328. *_delta_pg=delta_pg[j];
  329. return(mi[j]);
  330. }
  331. static long timing(const struct timeb *_start,const struct timeb *_stop){
  332. long ms;
  333. ms=(_stop->time-_start->time)*1000;
  334. ms+=(_stop->millitm-_start->millitm);
  335. return ms;
  336. }
  337. static void comp_predictors(const prob_ctx *_prob,solve_ctx _sol[NUM_PROCS],
  338. int _drop,int _mask[B_SZ*B_SZ*4*B_SZ*B_SZ]){
  339. int i;
  340. int j;
  341. #if PRINT_COMP
  342. fprintf(stderr,"xtx=[");
  343. for(j=0;j<4*B_SZ*B_SZ;j++){
  344. fprintf(stderr,"%s",j!=0?";":"");
  345. for(i=0;i<4*B_SZ*B_SZ;i++){
  346. fprintf(stderr,"%s%- 24.18G",i!=0?",":"",_prob->xtx[4*B_SZ*B_SZ*j+i]);
  347. }
  348. }
  349. fprintf(stderr,"];\n");
  350. fprintf(stderr,"xty=[");
  351. for(j=0;j<4*B_SZ*B_SZ;j++){
  352. fprintf(stderr,"%s",j!=0?";":"");
  353. for(i=0;i<B_SZ*B_SZ;i++){
  354. fprintf(stderr,"%s%- 24.18G",i!=0?",":"",_prob->xty[B_SZ*B_SZ*j+i]);
  355. }
  356. }
  357. fprintf(stderr,"];\n");
  358. #endif
  359. if(_drop>0){
  360. double delta_pg[B_SZ*B_SZ];
  361. int idx[B_SZ*B_SZ];
  362. for(j=0;j<B_SZ*B_SZ;j++){
  363. idx[j]=comp_delta_pg(_prob,_sol,j,&_mask[j*4*B_SZ*B_SZ],&delta_pg[j]);
  364. }
  365. while(_drop-->0){
  366. j=-1;
  367. for(i=0;i<B_SZ*B_SZ;i++){
  368. if(idx[i]!=-1&&(j==-1||delta_pg[i]<delta_pg[j])){
  369. j=i;
  370. }
  371. }
  372. #if PRINT_DROPS
  373. printf("Dropping (%2i,%2i) cost Pg=%g\n",j,idx[j],10*log10(delta_pg[j]));
  374. fflush(stdout);
  375. #endif
  376. _mask[j*4*B_SZ*B_SZ+idx[j]]=0;
  377. idx[j]=comp_delta_pg(_prob,_sol,j,&_mask[j*4*B_SZ*B_SZ],&delta_pg[j]);
  378. }
  379. }
  380. #pragma omp parallel for schedule(dynamic)
  381. for(j=0;j<B_SZ*B_SZ;j++){
  382. int tid;
  383. tid=omp_get_thread_num();
  384. solve(_prob,&_sol[tid],j,&_mask[j*4*B_SZ*B_SZ],_sol->beta_0,_sol->beta_1);
  385. }
  386. #if PRINT_COMP
  387. fprintf(stderr,"beta_1=[");
  388. for(j=0;j<4*B_SZ*B_SZ;j++){
  389. fprintf(stderr,"%s",j!=0?";":"");
  390. for(i=0;i<B_SZ*B_SZ;i++){
  391. fprintf(stderr,"%s%- 24.18G",i!=0?",":"",_sol->beta_1[4*B_SZ*B_SZ*i+j]);
  392. }
  393. }
  394. fprintf(stderr,"];\n");
  395. fprintf(stderr,"beta_0=[");
  396. for(i=0;i<B_SZ*B_SZ;i++){
  397. fprintf(stderr,"%s%- 24.18G",i!=0?",":"",_sol->beta_0[i]);
  398. }
  399. fprintf(stderr,"];\n");
  400. #endif
  401. }
  402. #if PRINT_PROGRESS
  403. static void print_progress(FILE *_fp,const char *_proc){
  404. int tid;
  405. tid=omp_get_thread_num();
  406. fprintf(_fp,"thread %i in %s\n",tid,_proc);
  407. }
  408. #endif
  409. static void ip_pre_block(void *_ctx,const unsigned char *_data,int _stride,
  410. int _bi,int _bj){
  411. classify_ctx *ctx;
  412. #if PRINT_PROGRESS
  413. if(_bi==0&&_bj==0){
  414. print_progress(stdout,"ip_pre_block");
  415. }
  416. #endif
  417. ctx=(classify_ctx *)_ctx;
  418. image_data_pre_block(&ctx->img,_data,_stride,_bi,_bj);
  419. }
  420. static void ip_fdct_block(void *_ctx,const unsigned char *_data,int _stride,
  421. int _bi,int _bj){
  422. classify_ctx *ctx;
  423. (void)_data;
  424. (void)_stride;
  425. #if PRINT_PROGRESS
  426. if(_bi==0&&_bj==0){
  427. print_progress(stdout,"ip_fdct_block");
  428. }
  429. #endif
  430. ctx=(classify_ctx *)_ctx;
  431. image_data_fdct_block(&ctx->img,_bi,_bj);
  432. }
  433. static void ip_add_block(void *_ctx,const unsigned char *_data,int _stride,
  434. int _bi,int _bj){
  435. classify_ctx *ctx;
  436. od_coeff *block;
  437. int by;
  438. int bx;
  439. int j;
  440. int i;
  441. double buf[5*B_SZ*B_SZ];
  442. int mode;
  443. double weight;
  444. (void)_data;
  445. (void)_stride;
  446. #if PRINT_PROGRESS
  447. if(_bi==0&&_bj==0){
  448. print_progress(stdout,"ip_add_block");
  449. }
  450. #endif
  451. ctx=(classify_ctx *)_ctx;
  452. for(by=0;by<=1;by++){
  453. for(bx=0;bx<=2-by;bx++){
  454. block=&ctx->img.fdct[ctx->img.fdct_stride*B_SZ*(_bj+by)+B_SZ*(_bi+bx)];
  455. for(j=0;j<B_SZ;j++){
  456. for(i=0;i<B_SZ;i++){
  457. buf[B_SZ*B_SZ*(3*by+bx)+j*B_SZ+i]=block[ctx->img.fdct_stride*j+i];
  458. }
  459. }
  460. }
  461. }
  462. #if SUBTRACT_DC
  463. for(i=0;i<4;i++){
  464. buf[4*B_SZ*B_SZ]-=0.25*buf[i*B_SZ*B_SZ];
  465. }
  466. #endif
  467. mode=ctx->img.mode[ctx->img.nxblocks*_bj+_bi];
  468. weight=ctx->img.weight[ctx->img.nxblocks*_bj+_bi];
  469. od_covmat_add(&ctx->pd[mode],buf,weight);
  470. }
  471. #if PRINT_BLOCKS
  472. static void ip_print_block(void *_ctx,const unsigned char *_data,int _stride,
  473. int _bi,int _bj){
  474. classify_ctx *ctx;
  475. #if PRINT_PROGRESS
  476. if(_bi==0&&_bj==0){
  477. print_progress(stdout,"ip_print_block");
  478. }
  479. #endif
  480. ctx=(classify_ctx *)_ctx;
  481. image_data_print_block(&ctx->img,_bi,_bj,stderr);
  482. }
  483. #endif
  484. static void ip_pred_block(void *_ctx,const unsigned char *_data,int _stride,
  485. int _bi,int _bj){
  486. classify_ctx *ctx;
  487. (void)_data;
  488. (void)_stride;
  489. #if PRINT_PROGRESS
  490. if(_bi==0&&_bj==0){
  491. print_progress("ip_pred_block");
  492. }
  493. #endif
  494. ctx=(classify_ctx *)_ctx;
  495. image_data_pred_block(&ctx->img,_bi,_bj);
  496. }
  497. static void ip_idct_block(void *_ctx,const unsigned char *_data,int _stride,
  498. int _bi,int _bj){
  499. classify_ctx *ctx;
  500. (void)_data;
  501. (void)_stride;
  502. #if PRINT_PROGRESS
  503. if(_bi==0&&_bj==0){
  504. print_progress("ip_idct_block");
  505. }
  506. #endif
  507. ctx=(classify_ctx *)_ctx;
  508. image_data_idct_block(&ctx->img,_bi,_bj);
  509. }
  510. static void ip_post_block(void *_ctx,const unsigned char *_data,int _stride,
  511. int _bi,int _bj){
  512. classify_ctx *ctx;
  513. (void)_data;
  514. (void)_stride;
  515. #if PRINT_PROGRESS
  516. if(_bi==0&&_bj==0){
  517. print_progress("ip_post_block");
  518. }
  519. #endif
  520. ctx=(classify_ctx *)_ctx;
  521. image_data_post_block(&ctx->img,_bi,_bj);
  522. }
  523. double b[OD_INTRA_NMODES][B_SZ*B_SZ];
  524. static void ip_stats_block(void *_ctx,const unsigned char *_data,int _stride,
  525. int _bi,int _bj){
  526. classify_ctx *ctx;
  527. #if PRINT_PROGRESS
  528. if(_bi==0&&_bj==0){
  529. print_progress("ip_stats_block");
  530. }
  531. #endif
  532. ctx=(classify_ctx *)_ctx;
  533. image_data_stats_block(&ctx->img,_data,_stride,_bi,_bj,&ctx->st);
  534. {
  535. od_coeff *block;
  536. double *pred;
  537. int mode;
  538. int j;
  539. int i;
  540. block=&ctx->img.fdct[ctx->img.fdct_stride*B_SZ*(_bj+1)+B_SZ*(_bi+1)];
  541. pred=&ctx->img.pred[ctx->img.pred_stride*B_SZ*_bj+B_SZ*_bi];
  542. mode=ctx->img.mode[ctx->img.nxblocks*_bj+_bi];
  543. for(j=0;j<B_SZ;j++){
  544. for(i=0;i<B_SZ;i++){
  545. double res;
  546. res=sqrt(OD_SCALE[j]*OD_SCALE[i])*abs(block[ctx->img.fdct_stride*j+i]-(od_coeff)floor(pred[ctx->img.pred_stride*j+i]+0.5));
  547. ctx->bits+=1+OD_LOG2(b[mode][j*B_SZ+i])+M_LOG2E/b[mode][j*B_SZ+i]*res;
  548. }
  549. }
  550. }
  551. }
  552. #if WRITE_IMAGES
  553. static void ip_files_block(void *_ctx,const unsigned char *_data,int _stride,
  554. int _bi,int _bj){
  555. classify_ctx *ctx;
  556. #if PRINT_PROGRESS
  557. if(_bi==0&&_bj==0){
  558. print_progress(stdout,"ip_files_block");
  559. }
  560. #endif
  561. ctx=(classify_ctx *)_ctx;
  562. image_data_files_block(&ctx->img,_data,_stride,_bi,_bj,&ctx->files);
  563. }
  564. #endif
/* Current training step: the loop counter of main()'s step loop, read by
    od_mode_block() (BITS_SELECT) and pred_finish() (WRITE_IMAGES). */
int step;
  566. static void vp8_mode_block(void *_ctx,const unsigned char *_data,int _stride,
  567. int _bi,int _bj){
  568. classify_ctx *ctx;
  569. unsigned char *mode;
  570. double *weight;
  571. #if PRINT_PROGRESS
  572. if(_bi==0&&_bj==0){
  573. print_progress(stdout,"ip_vp8_mode_block");
  574. }
  575. #endif
  576. ctx=(classify_ctx *)_ctx;
  577. mode=&ctx->img.mode[ctx->img.nxblocks*_bj+_bi];
  578. weight=&ctx->img.weight[ctx->img.nxblocks*_bj+_bi];
  579. *mode=vp8_select_mode(_data,_stride,weight);
  580. #if USE_WEIGHTS
  581. if(*mode==0){
  582. *weight=1;
  583. }
  584. #else
  585. *weight=1;
  586. #endif
  587. }
  588. static void od_mode_block(void *_ctx,const unsigned char *_data,int _stride,
  589. int _bi,int _bj){
  590. classify_ctx *ctx;
  591. unsigned char *mode;
  592. od_coeff *block;
  593. double *weight;
  594. (void)_data;
  595. (void)_stride;
  596. #if PRINT_PROGRESS
  597. if(_bi==0&&_bj==0){
  598. print_progress("od_mode_block");
  599. }
  600. #endif
  601. ctx=(classify_ctx *)_ctx;
  602. mode=&ctx->img.mode[ctx->img.nxblocks*_bj+_bi];
  603. block=&ctx->img.fdct[ctx->img.fdct_stride*B_SZ*(_bj+1)+B_SZ*(_bi+1)];
  604. weight=&ctx->img.weight[ctx->img.nxblocks*_bj+_bi];
  605. #if BITS_SELECT
  606. if(step==1){
  607. *mode=od_select_mode_satd(block,ctx->img.fdct_stride,weight);
  608. }
  609. else{
  610. *mode=od_select_mode_bits(block,ctx->img.fdct_stride,weight,b);
  611. }
  612. #else
  613. *mode=od_select_mode_satd(block,ctx->img.fdct_stride,weight);
  614. #endif
  615. #if USE_WEIGHTS
  616. if(*mode==0){
  617. *weight=1;
  618. }
  619. #else
  620. *weight=1;
  621. #endif
  622. }
  623. static int init_start(void *_ctx,const char *_name,const th_info *_ti,int _pli,
  624. int _nxblocks,int _nyblocks){
  625. classify_ctx *ctx;
  626. (void)_ti;
  627. (void)_pli;
  628. #if PRINT_PROGRESS
  629. print_progress(stdout,"init_start");
  630. #endif
  631. fprintf(stdout,"%s\n",_name);
  632. fflush(stdout);
  633. ctx=(classify_ctx *)_ctx;
  634. classify_ctx_set_image(ctx,_name,_nxblocks,_nyblocks);
  635. return EXIT_SUCCESS;
  636. }
  637. static int init_finish(void *_ctx){
  638. classify_ctx *ctx;
  639. #if PRINT_PROGRESS
  640. print_progress(stdout,"init_finish");
  641. #endif
  642. ctx=(classify_ctx *)_ctx;
  643. /*intra_stats_combine(&ctx->gb,&ctx->st);
  644. intra_stats_correct(&ctx->st);
  645. fprintf(stdout,"%s\n",ctx->img.name);
  646. intra_stats_print(&ctx->st,"Daala Intra Predictors",OD_SCALE);
  647. fflush(stdout);*/
  648. image_data_save_map(&ctx->img);
  649. classify_ctx_clear_image(ctx);
  650. return EXIT_SUCCESS;
  651. }
/* Per-block stages of the first pass, in the order listed; main() passes this
    table to ne_apply_to_blocks().
   Modes come from vp8_mode_block() here. */
const block_func INIT[]={
  ip_pre_block,
  ip_fdct_block,
  vp8_mode_block,
  ip_add_block,
#if PRINT_BLOCKS
  ip_print_block,
#endif
};
/* Number of entries in INIT. */
const int NINIT=sizeof(INIT)/sizeof(*INIT);
  662. static int pred_start(void *_ctx,const char *_name,const th_info *_ti,int _pli,
  663. int _nxblocks,int _nyblocks){
  664. classify_ctx *ctx;
  665. (void)_ti;
  666. (void)_pli;
  667. #if PRINT_PROGRESS
  668. print_progress(stdout,"pred_start");
  669. #endif
  670. ctx=(classify_ctx *)_ctx;
  671. classify_ctx_set_image(ctx,_name,_nxblocks,_nyblocks);
  672. image_data_load_map(&ctx->img);
  673. return EXIT_SUCCESS;
  674. }
  675. static int pred_finish(void *_ctx){
  676. classify_ctx *ctx;
  677. #if WRITE_IMAGES
  678. char suffix[16];
  679. #endif
  680. #if PRINT_PROGRESS
  681. print_progress(stdout,"pred_finish");
  682. #endif
  683. ctx=(classify_ctx *)_ctx;
  684. intra_stats_combine(&ctx->gb,&ctx->st);
  685. intra_stats_correct(&ctx->st);
  686. fprintf(stdout,"%s\n",ctx->img.name);
  687. intra_stats_print(&ctx->st,"Daala Intra Predictors",OD_SCALE);
  688. fflush(stdout);
  689. #if WRITE_IMAGES
  690. sprintf(suffix,"-step%02i",step);
  691. image_files_write(&ctx->files,ctx->img.name,suffix);
  692. #endif
  693. image_data_save_map(&ctx->img);
  694. classify_ctx_clear_image(ctx);
  695. return EXIT_SUCCESS;
  696. }
/* Per-block stages of every prediction pass, in the order listed; main()
    passes this table to ne_apply_to_blocks().
   Modes come from od_mode_block() here. */
const block_func PRED[]={
  ip_pre_block,
  ip_fdct_block,
  od_mode_block,
  ip_add_block,
  ip_pred_block,
  ip_stats_block,
  ip_idct_block,
  ip_post_block,
#if WRITE_IMAGES
  ip_files_block,
#endif
};
/* Number of entries in PRED. */
const int NPRED=sizeof(PRED)/sizeof(*PRED);
/* Padding value passed to ne_apply_to_blocks(); checked below to be at least
    3*B_SZ. */
#define PADDING (4*B_SZ)
#if PADDING<3*B_SZ
# error "PADDING must be at least 3*B_SZ"
#endif
/* Number of leading steps in which no coefficients are dropped. */
#define INIT_STEPS (10)
/* Number of dropping steps and coefficients dropped per mode in each of
    them, chosen per block size.
   DROP_STEPS*DROPS_PER_STEP stays below the B_SZ*B_SZ*4*B_SZ*B_SZ
    coefficients each mode starts with. */
#if B_SZ==4
# define DROP_STEPS (60)
# define DROPS_PER_STEP (16)
#elif B_SZ==8
# define DROP_STEPS (126)
# define DROPS_PER_STEP (128)
#elif B_SZ==16
# define DROP_STEPS (255)
# define DROPS_PER_STEP (1024)
#else
# error "Unsupported block size."
#endif
  728. int main(int _argc,const char *_argv[]){
  729. classify_ctx cls[NUM_PROCS];
  730. int i;
  731. int j;
  732. ne_filter_params_init();
  733. vp8_scale_init(VP8_SCALE);
  734. od_scale_init(OD_SCALE);
  735. #if WRITE_IMAGES
  736. intra_map_colors(COLORS,OD_INTRA_NMODES);
  737. #endif
  738. for(i=0;i<NUM_PROCS;i++){
  739. classify_ctx_init(&cls[i]);
  740. }
  741. omp_set_num_threads(NUM_PROCS);
  742. /* First pass across images uses VP8 mode selection. */
  743. ne_apply_to_blocks(cls,sizeof(*cls),0x1,PADDING,init_start,NINIT,INIT,
  744. init_finish,_argc,_argv);
  745. for(i=1;i<NUM_PROCS;i++){
  746. cls[0].n+=cls[i].n;
  747. }
  748. if(cls[0].n>0){
  749. prob_ctx prob;
  750. solve_ctx sol[NUM_PROCS];
  751. od_covmat ete;
  752. int mask[OD_INTRA_NMODES][B_SZ*B_SZ*4*B_SZ*B_SZ];
  753. struct timeb start;
  754. struct timeb stop;
  755. prob_ctx_init(&prob);
  756. for(i=0;i<NUM_PROCS;i++){
  757. solve_ctx_init(&sol[i]);
  758. }
  759. od_covmat_init(&ete,B_SZ*B_SZ);
  760. for(i=0;i<OD_INTRA_NMODES;i++){
  761. for(j=0;j<B_SZ*B_SZ*4*B_SZ*B_SZ;j++){
  762. mask[i][j]=1;
  763. }
  764. }
  765. ftime(&start);
  766. /* Each k-means step uses Daala mode selection. */
  767. for(step=1;step<=INIT_STEPS+DROP_STEPS;step++){
  768. int mults;
  769. int drops;
  770. mults=B_SZ*B_SZ*4*B_SZ*B_SZ;
  771. drops=0;
  772. if(step>INIT_STEPS){
  773. mults-=DROPS_PER_STEP*(step-INIT_STEPS);
  774. drops=DROPS_PER_STEP;
  775. }
  776. printf("Starting Step %02i (%i mults / block)\n",step,mults);
  777. for(j=0;j<OD_INTRA_NMODES;j++){
  778. /* Combine the gathered prediction data. */
  779. for(i=1;i<NUM_PROCS;i++){
  780. od_covmat_combine(&cls[0].pd[j],&cls[i].pd[j]);
  781. }
  782. prob_ctx_load(&prob,&cls[0].pd[j]);
  783. /* Update predictor model based on mults and drops. */
  784. #if PRINT_DROPS
  785. if(drops>0){
  786. printf("Mode %i\n",j);
  787. fflush(stdout);
  788. }
  789. #endif
  790. comp_predictors(&prob,sol,drops,mask[j]);
  791. /* Compute residual covariance for each mode. */
  792. prob_ctx_comp_error(&prob,&cls[0].pd[j],sol->beta_1);
  793. #if ZERO_MEAN
  794. {
  795. double mean[B_SZ*B_SZ];
  796. for(i=0;i<B_SZ*B_SZ;i++){
  797. mean[i]=0;
  798. }
  799. od_covmat_update(&ete,prob.ete,mean,cls[0].pd[j].w);
  800. }
  801. #else
  802. od_covmat_update(&ete,prob.ete,sol->beta_0,cls[0].pd[j].w);
  803. #endif
  804. #if !POOLED_COV
  805. od_covmat_correct(&ete);
  806. update_diversity(ete.cov,b[j],OD_SCALE);
  807. od_covmat_reset(&ete);
  808. #endif
  809. #if SUBTRACT_DC
  810. for(i=0;i<4;i++){
  811. OD_ASSERT(mask[j][i*B_SZ*B_SZ]);
  812. sol->beta_1[i*B_SZ*B_SZ]+=0.25;
  813. }
  814. #endif
  815. update_predictors(j,sol->beta_0,sol->beta_1,mask[j]);
  816. }
  817. #if POOLED_COV
  818. od_covmat_correct(&ete);
  819. for(j=0;j<OD_INTRA_NMODES;j++){
  820. update_diversity(ete.cov,b[j],OD_SCALE);
  821. }
  822. od_covmat_reset(&ete);
  823. #endif
  824. /* Reset the prediction data. */
  825. for(i=0;i<NUM_PROCS;i++){
  826. classify_ctx_reset(&cls[i]);
  827. }
  828. /* Reclassify based on the new model. */
  829. ne_apply_to_blocks(cls,sizeof(*cls),0x1,PADDING,pred_start,NPRED,PRED,
  830. pred_finish,_argc,_argv);
  831. ftime(&stop);
  832. printf("Finished Step %02i (%lims)\n",step,timing(&start,&stop));
  833. start=stop;
  834. /* Combine the gathered intra stats. */
  835. for(i=1;i<NUM_PROCS;i++){
  836. intra_stats_combine(&cls[0].gb,&cls[i].gb);
  837. cls[0].bits+=cls[i].bits;
  838. }
  839. printf("Step %02i Total Bits %-24.18G\n",step,cls[0].bits);
  840. intra_stats_correct(&cls[0].gb);
  841. intra_stats_print(&cls[0].gb,"Daala Intra Predictors",OD_SCALE);
  842. }
  843. prob_ctx_clear(&prob);
  844. for(i=0;i<NUM_PROCS;i++){
  845. solve_ctx_clear(&sol[i]);
  846. }
  847. od_covmat_clear(&ete);
  848. #if PRINT_BETAS
  849. print_predictors(stderr);
  850. #endif
  851. }
  852. for(i=0;i<NUM_PROCS;i++){
  853. classify_ctx_clear(&cls[i]);
  854. }
  855. return EXIT_SUCCESS;
  856. }