mse.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import numpy as np
  2. import PIL.Image as Image
  3. import patchreconst as pr
  4. PATCH_SIZE = 32
  5. COMPONENT_COUNT = 100
  6. PATCH_COUNT = 400000
  7. EPOCH_COUNT = 16
  8. CODE_COUNT = 800
  9. SPARSITY_PARAMETER = 4.0
  10. NON_NEGATIVE = True
  11. fileNamesFile = open('files.txt')
  12. fileNames = fileNamesFile.read().strip().split("\n")
  13. scBasisFile = 'basis/basisnn' + str(NON_NEGATIVE) + 'lamb' + str(SPARSITY_PARAMETER) + 'comps' + str(COMPONENT_COUNT) + 'codes' + str(CODE_COUNT) + 'patches' + str(PATCH_COUNT) + 'epochs' + str(EPOCH_COUNT) + '.npy'
  14. scBasis = np.load(scBasisFile)
  15. icaFilterFile = 'icabasis/basisluisicacomps' + str(COMPONENT_COUNT) + 'codes' + str(CODE_COUNT) + 'patches' + str(PATCH_COUNT) + '.npy'
  16. icaFilters = np.load(icaFilterFile)
  17. scMSE = 0.0
  18. icaMSE = 0.0
  19. v1MSE = 0.0
  20. for fileName in fileNames:
  21. image = np.load(fileName)
  22. imageWithNoise = np.load(fileName[:-4] + '_noise.npy')
  23. rowPad = 0
  24. if image.shape[0] % PATCH_SIZE != 0:
  25. lastPatchRows = image.shape[0] - (int(image.shape[0] / PATCH_SIZE) * PATCH_SIZE)
  26. rowPad = PATCH_SIZE - lastPatchRows
  27. if image.shape[1] % PATCH_SIZE != 0:
  28. lastPatchCols = image.shape[1] - (int(image.shape[1] / PATCH_SIZE) * PATCH_SIZE)
  29. colPad = PATCH_SIZE - lastPatchCols
  30. paddedImage = np.zeros((image.shape[0] + rowPad, image.shape[1] + colPad))
  31. paddedImage[:image.shape[0], :image.shape[1]] = image[:]
  32. paddedImage -= np.mean(paddedImage)
  33. paddedImage /= np.std(paddedImage)
  34. reconstV1 = np.zeros(paddedImage.shape)
  35. reconstSC = np.zeros(paddedImage.shape)
  36. reconstICA = np.zeros(paddedImage.shape)
  37. rows = int(paddedImage.shape[0] / PATCH_SIZE)
  38. cols = int(paddedImage.shape[1] / PATCH_SIZE)
  39. patches = np.zeros((rows, cols, PATCH_SIZE, PATCH_SIZE))
  40. for i in range(rows):
  41. for j in range(cols):
  42. row = i * PATCH_SIZE
  43. col = j * PATCH_SIZE
  44. rowEnd = row + PATCH_SIZE
  45. colEnd = col + PATCH_SIZE
  46. patches[i, j, :, :] = paddedImage[row:rowEnd, col:colEnd]
  47. patches = patches.reshape((-1, PATCH_SIZE, PATCH_SIZE))
  48. patches = patches.reshape((patches.shape[0], PATCH_SIZE * PATCH_SIZE))
  49. patches -= np.mean(patches, axis = -1)[:, np.newaxis]
  50. stds = np.std(patches, axis = -1)[:, np.newaxis]
  51. stds[np.where(stds == 0.0)] = 1.0
  52. patches /= stds
  53. patches = patches.reshape((-1, PATCH_SIZE, PATCH_SIZE))
  54. scCodes, icaCodes, v1Simple, v1cMean, angles = pr.responses(patches, scBasis, SPARSITY_PARAMETER, NON_NEGATIVE, icaFilters)
  55. patchesReconstV1 = pr.reconstructV1(v1Simple).reshape((rows, cols, PATCH_SIZE, PATCH_SIZE))
  56. patchesReconstSC = pr.reconstruct(scBasis, scCodes, v1cMean, angles).reshape((rows, cols, PATCH_SIZE, PATCH_SIZE))
  57. patchesReconstICA = pr.reconstruct(icaFilters, icaCodes, v1cMean, angles).reshape((rows, cols, PATCH_SIZE, PATCH_SIZE))
  58. for i in range(rows):
  59. for j in range(cols):
  60. row = i * PATCH_SIZE
  61. col = j * PATCH_SIZE
  62. rowEnd = row + PATCH_SIZE
  63. colEnd = col + PATCH_SIZE
  64. reconstV1[row:rowEnd, col:colEnd] = patchesReconstV1[i, j]
  65. reconstSC[row:rowEnd, col:colEnd] = patchesReconstSC[i, j]
  66. reconstICA[row:rowEnd, col:colEnd] = patchesReconstICA[i, j]
  67. reconstV1 = reconstV1[:image.shape[0], :image.shape[1]]
  68. reconstSC = reconstSC[:image.shape[0], :image.shape[1]]
  69. reconstICA = reconstICA[:image.shape[0], :image.shape[1]]
  70. pixelCount = float(image.shape[0] * image.shape[1])
  71. v1MSE += np.sum((reconstV1 - image) ** 2.0) / pixelCount
  72. scMSE += np.sum((reconstSC - image) ** 2.0) / pixelCount
  73. icaMSE += np.sum((reconstICA - image) ** 2.0) / pixelCount
  74. fileCount = float(len(fileNames))
  75. v1MSE /= fileCount
  76. scMSE /= fileCount
  77. icaMSE /= fileCount