sparsify.m 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. function [f] = sparsify(r)
  % pgUpdate  Prediction-gain loss table for output i under mask m.
  %
  % Inputs as built in fitModel: Cn = SX*C*SX and Dn = SX*D*SY are the
  % normalised X'X / X'Y cross products of the mean-removed data,
  % E = Yz'*Xz, F = Yz'*Yz, SX/SY the diagonal normalisers, m a 0/1 mask
  % with one row per output and one column per input, i the output row.
  %
  % Returns a 1 x size(m,2) row: entry j is the drop in prediction gain
  % (dB) of output i if active coefficient j were removed from the mask;
  % entries for coefficients already masked off are 0.
  function [pgmi] = pgUpdate(Cn,Dn,E,F,SX,SY,m,i)
    % Prediction gain (dB) of output i when regressing only on the inputs
    % selected by the logical mask sel (same expression as the final fit).
    maskedPg=@(sel) 10*log10(F(i,i)/(F(i,i)-E(i,sel)* ...
      (SX(sel,sel)*(Cn(sel,sel)\Dn(sel,i))*SY(i,i)^-1)));
    pgi=maskedPg(m(i,:)==1);
    % Size the table from the mask rather than assuming 64 inputs.
    pgmi=zeros(1,size(m,2));
    for j=find(m(i,:)==1)
      % Temporarily remove coefficient j and measure the gain lost.
      m(i,j)=0;
      pgmi(j)=pgi-maskedPg(m(i,:)==1);
      m(i,j)=1;
    end
  end
  % fitModel  Fit one mode's masked linear predictor Y ~ X*B1 + B0.
  %
  % s  label printed in the progress/summary lines
  % X  blocks x inputs predictor matrix; Y  blocks x outputs targets
  % m  outputs x inputs 0/1 mask of coefficients allowed to be non-zero
  % d  number of coefficients to remove greedily (least Pg loss first)
  %
  % Returns the updated mask m, the inputs x outputs coefficients B1
  % (zero where masked) and the 1 x outputs offsets B0.
  function [m,B1,B0] = fitModel(s,X,Y,m,d)
    nIn=size(X,2);
    nOut=size(Y,2);
    Xm=mean(X);
    Xz=bsxfun(@minus,X,Xm);
    Ym=mean(Y);
    Yz=bsxfun(@minus,Y,Ym);
    C=Xz'*Xz;
    D=Xz'*Yz;
    E=Yz'*Xz;
    F=Yz'*Yz;
    % Scale so Cn has a unit diagonal (presumably for numerical conditioning).
    SX=diag(1./diag(C).^(1/2));
    SY=diag(1./diag(F).^(1/2));
    Cn=SX*C*SX;
    Dn=SX*D*SY;
    % build the Pg masked look-up table (one row per output)
    pgm=zeros(nOut,nIn);
    for i=1:nOut
      pgm(i,:)=pgUpdate(Cn,Dn,E,F,SX,SY,m,i);
    end
    % drop the d coefficients with the least impact on Pg, one at a time
    for k=1:d
      active=find(m(:)==1);
      if isempty(active)
        % More drops requested than active coefficients: stop instead of
        % indexing an empty min() result.
        break;
      end
      [pg,a]=min(pgm(active));
      [i,j]=ind2sub(size(pgm),active(a));
      fprintf('Dropping (%i,%i) with pg=%g\n',i,j,pg)
      fflush(stdout);
      m(i,j)=0;
      % Only output i's regression changed; refresh just that row.
      pgm(i,:)=pgUpdate(Cn,Dn,E,F,SX,SY,m,i);
    end
    % Final least-squares fit on the surviving coefficients, mapped back
    % from the normalised domain.
    B1=zeros(nIn,nOut);
    B0=zeros(1,nOut);
    for i=1:nOut
      ind=m(i,:)==1;
      Bni=Cn(ind,ind)\Dn(ind,i);
      B1(ind,i)=SX(ind,ind)*Bni*SY(i,i)^-1;
      B0(i)=Ym(i)-Xm(ind)*B1(ind,i);
    end
    fprintf('%s Blocks %i Pg=%g\n',s,size(X,1),sum(10*log10(diag(F)./diag(F-E*B1)))/nOut);
    fflush(stdout);
  end
  % reclassify based on weighted SATD
  %
  % Assigns each block (row of X/Y) to the mode whose predictor
  % B0(j,:) + X*B1(j,:,:) gives the smallest weighted sum of absolute
  % errors against Y. B1 is modes x inputs x outputs, B0 modes x outputs;
  % the weights are a 1x16 reshape of a 4x4 outer product, so Y is
  % expected to have 16 columns. Returns c, a row vector of 1-based modes.
  function [c] = reclassify(X,Y,B1,B0)
    % Plain locals instead of 'global x=...': a global initialiser is
    % ignored when that global already exists, so a stale 'scale' from
    % elsewhere in the session would silently change the weights.
    OD_SCALE=[0.687397666275501251,0.691608410328626633,0.877750061452388763,0.874039031565189362];
    scale=reshape(OD_SCALE'*OD_SCALE,1,16);
    nModes=size(B1,1);
    nBlocks=size(X,1);
    % The per-mode coefficient matrices do not depend on the block, so
    % extract them once instead of squeezing inside the inner loop.
    Bmode=cell(nModes,1);
    for j=1:nModes
      Bmode{j}=squeeze(B1(j,:,:));
    end
    % Preallocate (row vector, as the original growth produced).
    c=zeros(1,nBlocks);
    fprintf('Reclassifying');
    fflush(stdout);
    for i=1:nBlocks
      if mod(i,10000)==0
        fprintf('.');
        fflush(stdout);
      end
      wSATD=zeros(nModes,1);
      for j=1:nModes
        wSATD(j)=sum(abs(Y(i,:)-(B0(j,:)+X(i,:)*Bmode{j})).*scale);
      end
      [~,c(i)]=min(wSATD);
    end
    fprintf('\n');
    fflush(stdout);
  end
  % printMode  Print the prediction gain of one mode's fitted predictor.
  %
  % X, Y are the blocks currently assigned to the mode, i its 1-based
  % index (printed 0-based), B1 inputs x outputs, B0 1 x outputs.
  % Pg is the mean over output columns of
  % 10*log10(sum of squared deviations of Y / same for the residual).
  function [] = printMode(X,Y,i,B1,B0)
    Ym=mean(Y);
    Yz=bsxfun(@minus,Y,Ym);
    F=Yz'*Yz;
    E=Y-bsxfun(@plus,X*B1,B0);
    Em=mean(E);
    Ez=bsxfun(@minus,E,Em);
    G=Ez'*Ez;
    % Average over however many outputs Y has instead of a fixed 16.
    fprintf('Mode %i Blocks %i Pg=%g\n',i-1,size(X,1),sum(10*log10(diag(F)./diag(G))/size(Y,2)));
    fflush(stdout);
  end
  %
  % kStep  One clustering iteration: refit every mode's predictor on its
  % assigned blocks, reassign blocks to modes, then report each mode.
  %
  % c  current 1-based mode of each block; X, Y  all blocks
  % m  modes x outputs x inputs coefficient mask
  % s  step number (printed); d  coefficients each mode drops this step
  function [c,m,B1,B0] = kStep(c,X,Y,m,s,d)
    nModes=size(m,1);
    % Shapes follow the mask instead of hard-coding 10 x 64 x 16.
    B1=zeros(nModes,size(m,3),size(m,2));
    B0=zeros(nModes,size(m,2));
    fprintf('Step %i (%i mults / block)\n',s,sum(m(:))/nModes-d);
    fflush(stdout);
    % fit the prediction model for each mode
    for i=1:nModes
      ind=find(c(:)==i);
      [m(i,:,:),B1(i,:,:),B0(i,:)]=fitModel(sprintf('Mode %i',i-1),X(ind,:),Y(ind,:),squeeze(m(i,:,:)),d);
    end
    c=reclassify(X,Y,B1,B0);
    % print each mode's gain under the new assignment
    for i=1:nModes
      ind=find(c(:)==i);
      printMode(X(ind,:),Y(ind,:),i,squeeze(B1(i,:,:)),B0(i,:))
    end
  end
  % Column layout of r as used below: column 3 is the initial mode label
  % (0-based, hence +1), columns 4:67 the 64 predictor values and
  % columns 68:83 the 16 target values of each block.
  if size(r,2) < 83
    error('sparsify: r must have at least 83 columns (got %i)', size(r,2));
  end
  % Plain '+': '.+' is an obsolete Octave-only spelling of the same thing.
  c=r(:,3)+1;
  X=r(:,4:67);
  Y=r(:,68:83);
  % one 16x64 coefficient mask per mode (1 = coefficient in use);
  % B1 (10x64x16) and B0 (10x16) are produced by kStep.
  m=ones(10,16,64);
  % initial 10 steps of full k-means
  for s=1:10
    [c,m,B1,B0]=kStep(c,X,Y,m,s,0);
  end
  % drop from 1024 multiplies per mode to 64 (60 steps x 16 drops)
  for s=11:70
    [c,m,B1,B0]=kStep(c,X,Y,m,s,16);
  end
  % The declared output f was never assigned, so 'f = sparsify(r)'
  % raised an error; return the final assignment and model instead.
  f=struct('c',c,'m',m,'B1',B1,'B0',B0);
  118. end