patchpredict.py 1.8 KB

(51 lines)
  1. import numpy as np
  2. import PIL.Image as Image
  3. import patchreconst as pr
  4. import matplotlib.pyplot as plt
  5. PATCH_SIZE = 32
  6. COMPONENT_COUNT = 100
  7. PATCH_COUNT = 400000
  8. EPOCH_COUNT = 16
  9. CODE_COUNT = 800
  10. SPARSITY_PARAMETER = 4.0
  11. NON_NEGATIVE = True
  12. scBasisFile = 'basis/basisnn' + str(NON_NEGATIVE) + 'lamb' + str(SPARSITY_PARAMETER) + 'comps' + str(COMPONENT_COUNT) + 'codes' + str(CODE_COUNT) + 'patches' + str(PATCH_COUNT) + 'epochs' + str(EPOCH_COUNT) + '.npy'
  13. scBasis = np.load(scBasisFile)
  14. icaFilterFile = 'icabasis/basisluisicacomps' + str(COMPONENT_COUNT) + 'codes' + str(CODE_COUNT) + 'patches' + str(PATCH_COUNT) + '.npy'
  15. icaFilters = np.load(icaFilterFile)
  16. patches = np.load('patches_1000_mod.npy')
  17. patchesOrig = np.load('patches_1000.npy')[:, :, :, 0]
  18. rmBounds = np.load('patches_1000_rmbounds.npy')
  19. rmBounds = (rmBounds[0], rmBounds[1], rmBounds[2])
  20. scCodes, icaCodes, v1Simple, v1cMean, angles = pr.responses(patches, scBasis, SPARSITY_PARAMETER, NON_NEGATIVE, icaFilters)
  21. pcaTransformed, v1C = pr.responsesPCAV1C(patches)
  22. patchesReconstV1 = pr.reconstructV1(v1Simple)
  23. patchesReconstPCA = pr.reconstructPCA(pcaTransformed, v1cMean, angles)
  24. patchesReconstSC = pr.reconstruct(scBasis, scCodes, v1cMean, angles)
  25. patchesReconstICA = pr.reconstruct(icaFilters, icaCodes, v1cMean, angles)
  26. '''
  27. for i in range(patchesReconstV1.shape[0]):
  28. f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharey = True)
  29. ax1.imshow(patchesOrig[i], cmap = 'gray', interpolation = 'none')
  30. ax1.title.set_text('Original patch')
  31. ax2.imshow(patches[i], cmap = 'gray', interpolation = 'none')
  32. ax2.title.set_text('Modified patch')
  33. ax3.imshow(patchesReconstV1[i], cmap = 'gray', interpolation = 'none')
  34. ax3.title.set_text('V1')
  35. ax4.imshow(patchesReconstSC[i], cmap = 'gray', interpolation = 'none')
  36. ax4.title.set_text('SC')
  37. plt.savefig('pred/' + str(i) + '.png')
  38. plt.close()
  39. '''