# patchpredict2.py (984 B)

# (The listing's line-number gutter ran 1-25; only lines 1-18 were captured below.)
  1. import numpy as np
  2. import PIL.Image as Image
  3. import patchreconst as pr
  4. PATCH_SIZE = 32
  5. COMPONENT_COUNT = 100
  6. PATCH_COUNT = 400000
  7. EPOCH_COUNT = 16
  8. CODE_COUNT = 800
  9. SPARSITY_PARAMETER = 0.5
  10. NON_NEGATIVE = True
  11. scBasisFile = 'basis/basisnn' + str(NON_NEGATIVE) + 'lamb' + str(SPARSITY_PARAMETER) + 'comps' + str(COMPONENT_COUNT) + 'codes' + str(CODE_COUNT) + 'patches' + str(PATCH_COUNT) + 'epochs' + str(EPOCH_COUNT) + '.npy'
  12. scBasis = np.load(scBasisFile)
  13. icaFilterFile = 'icabasis/basisluisicacomps' + str(COMPONENT_COUNT) + 'codes' + str(CODE_COUNT) + 'patches' + str(PATCH_COUNT) + '.npy'
  14. icaFilters = np.load(icaFilterFile)
  15. patches = np.load('patches_1000.npy')[:, :, :, 0]
  16. scCodes, icaCodes, v1Simple, v1SimpleCropped, v1cMean, angles = pr.responsesCropV1(patches, scBasis, SPARSITY_PARAMETER, NON_NEGATIVE, icaFilters)
  17. patchesReconstSC = pr.reconstructForV1(scBasis, scCodes, v1cMean, angles)
  18. patchesReconstICA = pr.reconstructForV1(icaFilters, icaCodes, v1cMean, angles)