load-dataset.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # load-dataset.py
  5. #
  6. # 2022 Stephen Stengel <stephen.stengel@cwu.edu>
  7. #
  8. # Script to load the dataset into file form for easy importing.
  9. #
  10. print("Loading imports...")
  11. import os
  12. import skimage
  13. from skimage.io import imsave
  14. from skimage.util import img_as_uint
  15. import shutil
  16. import matplotlib.pyplot as plt
  17. from pathlib import Path
  18. from tqdm import tqdm #Pretty loading bars
  19. import numpy as np
  20. import tensorflow as tf
  21. print("Done!")
  22. CLASS_INTERESTING = 0
  23. CLASS_NOT_INTERESTING = 1
  24. CLASS_INTERESTING_STRING = "interesting"
  25. CLASS_NOT_INTERESTING_STRING = "not"
  26. DATASET_DIRECTORY = os.path.normpath("./ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/614s/")
  27. INTERESTING_DIRECTORY = os.path.join(DATASET_DIRECTORY, "interesting")
  28. NOT_INTERESTING_DIRECTORY = os.path.join(DATASET_DIRECTORY, "not interesting")
  29. DATASET_COPY_FOLDER = os.path.normpath("./tmpdata/")
  30. DATASET_COPY_FOLDER_INT = os.path.join(DATASET_COPY_FOLDER, CLASS_INTERESTING_STRING)
  31. DATASET_COPY_FOLDER_NOT = os.path.join(DATASET_COPY_FOLDER, CLASS_NOT_INTERESTING_STRING)
  32. DATASET_PNG_FOLDER = os.path.normpath("./datasets-as-png/")
  33. DATASET_PNG_FOLDER_TRAIN = os.path.join(DATASET_PNG_FOLDER, "train")
  34. DATASET_PNG_FOLDER_TRAIN_INT = os.path.join(DATASET_PNG_FOLDER_TRAIN, CLASS_INTERESTING_STRING)
  35. DATASET_PNG_FOLDER_TRAIN_NOT = os.path.join(DATASET_PNG_FOLDER_TRAIN, CLASS_NOT_INTERESTING_STRING)
  36. DATASET_PNG_FOLDER_VAL = os.path.join(DATASET_PNG_FOLDER, "val")
  37. DATASET_PNG_FOLDER_VAL_INT = os.path.join(DATASET_PNG_FOLDER_VAL, CLASS_INTERESTING_STRING)
  38. DATASET_PNG_FOLDER_VAL_NOT = os.path.join(DATASET_PNG_FOLDER_VAL, CLASS_NOT_INTERESTING_STRING)
  39. DATASET_PNG_FOLDER_TEST = os.path.join(DATASET_PNG_FOLDER, "test")
  40. DATASET_PNG_FOLDER_TEST_INT = os.path.join(DATASET_PNG_FOLDER_TEST, CLASS_INTERESTING_STRING)
  41. DATASET_PNG_FOLDER_TEST_NOT = os.path.join(DATASET_PNG_FOLDER_TEST, CLASS_NOT_INTERESTING_STRING)
  42. DATASET_SAVE_DIR = os.path.normpath("./dataset/")
  43. TRAIN_SAVE_DIRECTORY = os.path.join(DATASET_SAVE_DIR, "train")
  44. VAL_SAVE_DIRECTORY = os.path.join(DATASET_SAVE_DIR, "val")
  45. TEST_SAVE_DIRECTORY = os.path.join(DATASET_SAVE_DIR, "test")
  46. ALL_FOLDERS_LIST = [
  47. DATASET_DIRECTORY,
  48. INTERESTING_DIRECTORY,
  49. NOT_INTERESTING_DIRECTORY,
  50. DATASET_COPY_FOLDER,
  51. DATASET_COPY_FOLDER_INT,
  52. DATASET_COPY_FOLDER_NOT,
  53. DATASET_PNG_FOLDER,
  54. DATASET_PNG_FOLDER_TRAIN,
  55. DATASET_PNG_FOLDER_TRAIN_INT,
  56. DATASET_PNG_FOLDER_TRAIN_NOT,
  57. DATASET_PNG_FOLDER_VAL,
  58. DATASET_PNG_FOLDER_VAL_INT,
  59. DATASET_PNG_FOLDER_VAL_NOT,
  60. DATASET_PNG_FOLDER_TEST,
  61. DATASET_PNG_FOLDER_TEST_INT,
  62. DATASET_PNG_FOLDER_TEST_NOT,
  63. DATASET_SAVE_DIR,
  64. TRAIN_SAVE_DIRECTORY,
  65. VAL_SAVE_DIRECTORY,
  66. TEST_SAVE_DIRECTORY
  67. ]
  68. HIDDEN_DOWNLOAD_FLAG_FILE = ".isnotfirstdownload"
  69. CLASS_NAMES_LIST_INT = [CLASS_INTERESTING, CLASS_NOT_INTERESTING]
  70. CLASS_NAMES_LIST_STR = [CLASS_INTERESTING_STRING, CLASS_NOT_INTERESTING_STRING]
  71. TEST_PRINTING = False
  72. IS_SAVE_THE_DATASETS = True
  73. IS_SAVE_THE_PNGS = True
  74. IS_DOWNLOAD_PICTURES = False
  75. def main(args):
  76. print("Hello! This is the Animal Crossing Dataset Loader!")
  77. makeDirectories(ALL_FOLDERS_LIST)
  78. checkArgs(args)
  79. print("DATASET_DIRECTORY: " + str(DATASET_DIRECTORY))
  80. print("Creating file structure...")
  81. createFileStructure(INTERESTING_DIRECTORY, DATASET_COPY_FOLDER_INT)
  82. createFileStructure(NOT_INTERESTING_DIRECTORY, DATASET_COPY_FOLDER_NOT)
  83. print("Done!")
  84. interestingFNames = getListOfAnimalPicsInOneClass(DATASET_COPY_FOLDER_INT)
  85. notInterestingFNames = getListOfAnimalPicsInOneClass(DATASET_COPY_FOLDER_NOT)
  86. #These WILL change later
  87. # ~ img_width = 400
  88. # ~ img_height = 300
  89. img_width = 100
  90. img_height = 100
  91. batch_size = 32
  92. percentageTrain = 0.6
  93. percentageTestToVal = 0.75
  94. print("creating the datasets...")
  95. train_ds, val_ds, test_ds = createAnimalsDataset(
  96. DATASET_COPY_FOLDER, img_height, img_width, batch_size, percentageTrain, percentageTestToVal)
  97. print("Done!")
  98. print("Saving datasets...")
  99. if IS_SAVE_THE_DATASETS:
  100. saveDatasets(
  101. train_ds, TRAIN_SAVE_DIRECTORY,
  102. val_ds, VAL_SAVE_DIRECTORY,
  103. test_ds, TEST_SAVE_DIRECTORY)
  104. print("Done!")
  105. else:
  106. print("Saving disabled for now!")
  107. if IS_SAVE_THE_PNGS:
  108. print("Saving the datasets as image files...")
  109. saveDatasetAsPNG(train_ds, DATASET_PNG_FOLDER_TRAIN)
  110. saveDatasetAsPNG(val_ds, DATASET_PNG_FOLDER_VAL)
  111. saveDatasetAsPNG(test_ds, DATASET_PNG_FOLDER_TEST)
  112. else:
  113. print("PNG saving disabled for now!")
  114. print("Deleting the temporary image folder...")
  115. shutil.rmtree(DATASET_COPY_FOLDER)
  116. os.sync()
  117. print("Done!")
  118. return 0
  119. # Creates the necessary directories.
  120. def makeDirectories(listOfFoldersToCreate):
  121. #Clear out old files -- Justin Case
  122. if os.path.isdir(DATASET_SAVE_DIR):
  123. shutil.rmtree(DATASET_SAVE_DIR, ignore_errors = True)
  124. if os.path.isdir(DATASET_PNG_FOLDER):
  125. shutil.rmtree(DATASET_PNG_FOLDER, ignore_errors = True)
  126. if os.path.isdir(DATASET_COPY_FOLDER):
  127. shutil.rmtree(DATASET_COPY_FOLDER, ignore_errors = True)
  128. for folder in listOfFoldersToCreate:
  129. if not os.path.isdir(folder):
  130. os.makedirs(folder)
  131. # Retrieves the images if they're not here
  132. def retrieveImages():
  133. print("Retrieving images...")
  134. os.system("wget -e robots=off -r -np --mirror https://ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/")
  135. print("Done!")
  136. #Checks if a flag file is in place to determine if the dataset should download from the ftp server.
  137. def isDownloadedFlagFileSet():
  138. if not os.path.isfile(HIDDEN_DOWNLOAD_FLAG_FILE):
  139. Path(HIDDEN_DOWNLOAD_FLAG_FILE).touch(exist_ok=True)
  140. return False
  141. return True
  142. #Takes some images from the validation set and sets the aside for the test set.
  143. def createTestSet(val_ds, percentageTestToVal):
  144. length = np.asarray(val_ds.cardinality())
  145. numForTest = int(length * percentageTestToVal)
  146. test_dataset = val_ds.take(numForTest)
  147. val_ds = val_ds.skip(numForTest)
  148. return val_ds, test_dataset
  149. def saveDatasets(train_ds, trainDir, val_ds, valDir, test_ds, testDir):
  150. tf.data.experimental.save(train_ds, trainDir)
  151. tf.data.experimental.save(val_ds, valDir)
  152. tf.data.experimental.save(test_ds, testDir)
  153. #The batching makes them get stuck together in batches. Right now that's 32 images.
  154. #So whenever you take one from the set, you get a batch of 32 images.
  155. # percentageTrain is a decimal from 0 to 1 of the percent data that should be for train
  156. # percentageTestToVal is a number from 0 to 1 of the percentage of the non-train data for use as test
  157. def createAnimalsDataset(baseDirectory, img_height, img_width, batch_size, percentageTrain, percentageTestToVal):
  158. valSplit = 1 - percentageTrain
  159. train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  160. baseDirectory,
  161. labels = "inferred",
  162. label_mode = "int",
  163. class_names = CLASS_NAMES_LIST_STR, #must match directory names
  164. color_mode = "grayscale",
  165. validation_split = valSplit,
  166. subset="training",
  167. seed=123,
  168. image_size=(img_height, img_width),
  169. batch_size=batch_size)
  170. val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  171. baseDirectory,
  172. labels = "inferred",
  173. label_mode = "int",
  174. class_names = CLASS_NAMES_LIST_STR, #must match directory names
  175. color_mode = "grayscale",
  176. validation_split = valSplit,
  177. subset="validation",
  178. seed=123,
  179. image_size=(img_height, img_width),
  180. batch_size=batch_size)
  181. val_ds, test_ds = createTestSet(val_ds, percentageTestToVal)
  182. AUTOTUNE = tf.data.AUTOTUNE
  183. normalization_layer = tf.keras.layers.Rescaling(1./255) #for newer versions of tensorflow
  184. # ~ normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255) #for old versions
  185. train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
  186. val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
  187. test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
  188. flippyBoy = tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal")
  189. train_ds = train_ds.map(lambda x, y: (flippyBoy(x), y), num_parallel_calls=AUTOTUNE)
  190. val_ds = val_ds.map(lambda x, y: (flippyBoy(x), y), num_parallel_calls=AUTOTUNE)
  191. myRotate = tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
  192. train_ds = train_ds.map(lambda x, y: (myRotate(x), y), num_parallel_calls=AUTOTUNE)
  193. val_ds = val_ds.map(lambda x, y: (myRotate(x), y), num_parallel_calls=AUTOTUNE)
  194. train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
  195. val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
  196. test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)
  197. if TEST_PRINTING:
  198. print("Showing some unaltered images from the testing set...")
  199. printSample(test_ds)
  200. if TEST_PRINTING:
  201. print("Showing some augmented images from the training set...")
  202. printSample(train_ds)
  203. return train_ds, val_ds, test_ds
  204. # Prints first nine images from the first batch of the dataset.
  205. # It's random as long as you shuffle the dataset! ;)
  206. def printSample(in_ds):
  207. plt.figure(figsize=(10, 10))
  208. for img, label in in_ds.take(1):
  209. for i in tqdm(range(9)):
  210. ax = plt.subplot(3, 3, i + 1)
  211. myImg = np.asarray(img)
  212. plt.imshow(np.asarray(myImg[i]), cmap="gray")
  213. plt.title( CLASS_NAMES_LIST_STR[ np.asarray(label[i]) ] )
  214. plt.axis("off")
  215. plt.show()
  216. #save all images from dataset to file as png
  217. def saveDatasetAsPNG(in_ds, saveFolder):
  218. i = 0
  219. for batch in tqdm(in_ds):
  220. imgArr = np.asarray(batch[0])
  221. labelArr = np.asarray(batch[1])
  222. for j in range(len(imgArr)):
  223. thisImg = imgArr[j]
  224. thisImg = img_as_uint(thisImg)
  225. thisLabel = labelArr[j]
  226. filenamestring = os.path.join(saveFolder, CLASS_NAMES_LIST_STR[thisLabel], str(i) + ".png")
  227. imsave(filenamestring, thisImg)
  228. i = i + 1
  229. def createFileStructure(baseDirSource, destination):
  230. recursivelyCopyAllFilesInFolderToOneDestinationFolder(baseDirSource, destination)
  231. def recursivelyCopyAllFilesInFolderToOneDestinationFolder(baseDirSource, destination):
  232. print("Copying files to " + str(destination))
  233. cpyFiles = getListOfFilenames(baseDirSource)
  234. for thisName in tqdm(cpyFiles):
  235. try:
  236. shutil.copy(thisName, destination)
  237. except:
  238. print("copy skipping: " + str(thisName))
  239. def getListOfAnimalPicsInOneClass(classDir):
  240. dirNames = getListOfDirNames(classDir)
  241. picNames = []
  242. for dName in dirNames:
  243. picNames.extend( getCuratedListOfFileNames(dName) )
  244. return picNames
  245. def getCuratedListOfFileNames(directoryName):
  246. thisNames = getListOfFilenames(directoryName)
  247. thisNames = keepOnlyJPG(thisNames)
  248. return thisNames
  249. def keepOnlyJPG(inList):
  250. for thingy in inList:
  251. pathParts = os.path.splitext(thingy)
  252. if pathParts[-1].lower() != ".jpg" and pathParts[-1].lower() != ".jpeg":
  253. print("excluding non-jpg!: " + str(thingy))
  254. inList.remove(thingy)
  255. return inList
  256. #Returns a list of filenames from the input directory
  257. def getListOfFilenames(baseDirectory):
  258. myNames = []
  259. for (root, dirNames, fileNames) in os.walk(baseDirectory):
  260. for aFile in fileNames:
  261. myNames.append( os.path.join( root, aFile ) )
  262. return myNames
  263. #Returns a list of dirnames from the base
  264. def getListOfDirNames(baseDirectory):
  265. myNames = []
  266. for (root, dirNames, fileNames) in os.walk(baseDirectory):
  267. for aDir in dirNames:
  268. myNames.append( os.path.join( root, aDir ) )
  269. return myNames
  270. def checkArgs(args):
  271. shouldIRetrieveImages = False
  272. #for people not using a terminal; they can set the flag.
  273. if IS_DOWNLOAD_PICTURES:
  274. shouldIRetrieveImages = True
  275. if len(args) > 1:
  276. downloadArgs = ["--download", "-download", "download", "d", "-d", "--d"]
  277. if not set(downloadArgs).isdisjoint(args):
  278. shouldIRetrieveImages = True
  279. #for the first time user
  280. if not isDownloadedFlagFileSet():
  281. shouldIRetrieveImages = True
  282. if shouldIRetrieveImages:
  283. retrieveImages()
  284. if __name__ == '__main__':
  285. import sys
  286. sys.exit(main(sys.argv))