#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# load-dataset.py
#
# 2022 Stephen Stengel <stephen.stengel@cwu.edu>
#
# Script to load the dataset into file form for easy importing.
#

print("Loading imports...")
import os
import skimage
import shutil
import matplotlib.pyplot as plt
from pathlib import Path
import subprocess
import numpy as np
import tensorflow as tf
print("Done!")

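#Paths and settings: the downloaded WSDOT thermal images, the temporary
#flattened working copy, the save locations for the dataset splits, the class
#labels, and the runtime flags.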
DATASET_DIRECTORY = "./ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/614s/"
INTERESTING_DIRECTORY = "./ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/614s/interesting/"
NOT_INTERESTING_DIRECTORY = "./ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/614s/not interesting/"

DATASET_COPY_FOLDER = "./tmpdata/"
DATASET_COPY_FOLDER_INT = "./tmpdata/int/"
DATASET_COPY_FOLDER_NOT = "./tmpdata/not/"

DATASET_SAVE_DIR = "./dataset/"
TRAIN_SAVE_DIRECTORY = "./dataset/train/"
VAL_SAVE_DIRECTORY = "./dataset/val/"
TEST_SAVE_DIRECTORY = "./dataset/test/"

CLASS_INTERESTING = 0
CLASS_NOT_INTERESTING = 1

TEST_PRINTING = False
IS_DOWNLOAD_PICTURES = False

HIDDEN_DOWNLOAD_FLAG_FILE = ".isnotfirstdownload"

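#Entry point. Builds the folder structure, optionally starts a wget download
#of the image set, flattens each class folder into ./tmpdata/, and builds the
#train/validation/test datasets. Saving the datasets to disk is currently disabled.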
def main(args):
    print("Hello! This is the Animal Crossing Dataset Loader!")
    makeDirectories()
    wgetPID = checkArgs(args)
    print("DATASET_DIRECTORY: " + str(DATASET_DIRECTORY))

    print("Creating file structure...")
    createFileStructure(INTERESTING_DIRECTORY, DATASET_COPY_FOLDER_INT)
    createFileStructure(NOT_INTERESTING_DIRECTORY, DATASET_COPY_FOLDER_NOT)
    print("Done!")

    interestingFNames = getListOfAnimalPicsInOneClass(DATASET_COPY_FOLDER_INT)
    notInterestingFNames = getListOfAnimalPicsInOneClass(DATASET_COPY_FOLDER_NOT)

    #This is only useful if the files are already downloaded. Waits for wget
    #to finish updating after createFileStructure has run and the names have
    #been gathered.
    waitForDownloadToFinish(wgetPID)

    #These could change later
    img_height = 100
    img_width = 100
    # ~ img_height = 600
    # ~ img_width = 800
    batch_size = 32

    print("Creating the datasets...")
    train_ds, val_ds, test_ds = createAnimalsDataset(DATASET_COPY_FOLDER, img_height, img_width, batch_size)
    print("Done!")

    #Might not be super useful but it's possible
    #https://www.tensorflow.org/api_docs/python/tf/data/experimental/save
    print("Saving datasets...")
    # ~ tf.data.experimental.save(train_ds, TRAIN_SAVE_DIRECTORY)
    # ~ tf.data.experimental.save(val_ds, VAL_SAVE_DIRECTORY)
    # ~ tf.data.experimental.save(test_ds, TEST_SAVE_DIRECTORY)
    print("Disabled for now!")
    print("Done!")

    return 0

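#Note on reloading: if the tf.data.experimental.save() calls in main() are
#enabled, the splits could presumably be restored later with something like
#the sketch below (untested; the exact signature of tf.data.experimental.load
#depends on the TensorFlow version, with older releases also requiring an
#element_spec argument).
# ~ train_ds = tf.data.experimental.load(TRAIN_SAVE_DIRECTORY)
# ~ val_ds = tf.data.experimental.load(VAL_SAVE_DIRECTORY)
# ~ test_ds = tf.data.experimental.load(TEST_SAVE_DIRECTORY)
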
#Creates the working directories if they do not already exist.
def makeDirectories():
    neededDirs = [
        DATASET_COPY_FOLDER, DATASET_COPY_FOLDER_INT, DATASET_COPY_FOLDER_NOT,
        DATASET_SAVE_DIR, TRAIN_SAVE_DIRECTORY, VAL_SAVE_DIRECTORY,
        TEST_SAVE_DIRECTORY, DATASET_DIRECTORY,
    ]
    for thisDir in neededDirs:
        os.makedirs(thisDir, exist_ok=True)

# Retrieves the images if they're not here
# note: does not UPDATE images, need to implement that
def retrieveImages():
    print("Retrieving images...")
    # ~ wgetPID = subprocess.Popen(["ogg123", "james-brown-dead.ogg", "-r", "-q"])
    wgetPID = subprocess.Popen(["wget", "-e", "robots=off", "-r", "-np", "--mirror", "https://ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/"])
    # ~ os.system("wget -e robots=off -r -np --mirror https://ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/")
    print("Done!")

    return wgetPID

#Checks if a flag file is in place to determine if the dataset should download from the ftp server.
def isDownloadedFlagFileSet():
    if not os.path.isfile(HIDDEN_DOWNLOAD_FLAG_FILE):
        Path(HIDDEN_DOWNLOAD_FLAG_FILE).touch(exist_ok=True)
        return False

    return True

#Waits for wget to finish. wgetPID is the Popen handle of the wget subprocess.
#Not sure if it will work on windows!!!!!!!!!!!!!!!
def waitForDownloadToFinish(wgetPID):
    if wgetPID is not None:
        wgetPID.wait()
        print(" -- Finished Downloading Dataset -- ")
        print("wget done")

#Takes some images from the validation set and sets them aside for the test set.
def createTestSet(val_ds):
    val_batches = tf.data.experimental.cardinality(val_ds)
    test_dataset = val_ds.take(val_batches // 5)
    val_ds = val_ds.skip(val_batches // 5)

    return val_ds, test_dataset

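#Builds the training, validation, and test datasets from the flattened copy
#folder. 20% of the images are held out for validation, and a fifth of those
#validation batches are then set aside as the test set. Images are resized,
#batched, rescaled to [0, 1], and prefetched.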
#Must use tf.keras.layers.Rescaling(1./255) as first layer in model !!!
def createAnimalsDataset(baseDirectory, img_height, img_width, batch_size):
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        baseDirectory,
        color_mode="rgb",
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        baseDirectory,
        color_mode="rgb",
        validation_split=0.2,
        subset="validation",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)

    #class_names must be read before the preview plot below uses it.
    class_names = train_ds.class_names
    print("class names: " + str(class_names))

    if TEST_PRINTING:
        plt.figure(figsize=(10, 10))
        for images, labels in train_ds.take(1):
            for i in range(9):
                ax = plt.subplot(3, 3, i + 1)
                plt.imshow(images[i].numpy().astype("uint8"))
                plt.title(class_names[labels[i]])
                plt.axis("off")
        plt.show()

    # ~ normalization_layer = tf.keras.layers.Rescaling(1./255) #for new versions
    normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255) #for old versions
    n_train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
    n_val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
    n_val_ds, n_test_ds = createTestSet(n_val_ds)

    #The AUTOTUNE constant moved out of tf.data.experimental in newer TensorFlow versions.
    AUTOTUNE = tf.data.AUTOTUNE
    n_train_ds = n_train_ds.prefetch(buffer_size=AUTOTUNE)
    n_val_ds = n_val_ds.prefetch(buffer_size=AUTOTUNE)
    n_test_ds = n_test_ds.prefetch(buffer_size=AUTOTUNE)

    #Could do augmentation here on train and val, leaving test unaugmented.
    #It was causing errors, so it is skipped for now. The datasets would have
    #to be uncoupled from the extra data added in the normalization and
    #prefetch steps.
    #see: https://www.tensorflow.org/text/tutorials/transformer
    # ~ flippyBoy = tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal")
    # ~ rotate = tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
    # ~ n_train_ds = rotate(n_train_ds)
    # ~ n_val_ds = rotate(n_val_ds)
    # ~ n_test_ds = rotate(n_test_ds)

    return n_train_ds, n_val_ds, n_test_ds

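#Copies every file found under baseDirSource into the single flat destination folder.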
def createFileStructure(baseDirSource, destination):
    copyDatasetToTMP(baseDirSource, destination)

    dirNames = getListOfDirNames(destination)
    for dName in dirNames:
        copyDatasetToTMP(dName, destination)

#Copies the files from baseDirSource into the destination folder, skipping
#anything that cannot be copied.
def copyDatasetToTMP(baseDirSource, destination):
    cpyFiles = getListOfFilenames(baseDirSource)
    for thisName in cpyFiles:
        try:
            shutil.copy(thisName, destination)
        except OSError:
            print("copy skipping: " + str(thisName))

def getListOfAnimalPicsInOneClass(classDir):
    dirNames = getListOfDirNames(classDir)
    picNames = []
    for dName in dirNames:
        picNames.extend( getCuratedListOfFileNames(dName) )

    return picNames

def getCuratedListOfFileNames(directoryName):
    thisNames = getListOfFilenames(directoryName)
    thisNames = keepOnlyJPG(thisNames)

    return thisNames

#Removing items from a list while iterating over it skips elements, so a new
#list of the kept names is built instead.
def keepOnlyJPG(inList):
    keptNames = []
    for thingy in inList:
        extension = os.path.splitext(thingy)[-1].lower()
        if extension in (".jpg", ".jpeg"):
            keptNames.append(thingy)
        else:
            print("excluding non-jpg!: " + str(thingy))

    return keptNames

#Returns a list of filenames from the input directory
def getListOfFilenames(baseDirectory):
    myNames = []
    for (root, dirNames, fileNames) in os.walk(baseDirectory):
        for aFile in fileNames:
            myNames.append( os.path.join( root, aFile ) )

    return myNames

#Returns a list of dirnames from the base
def getListOfDirNames(baseDirectory):
    myNames = []
    for (root, dirNames, fileNames) in os.walk(baseDirectory):
        for aDir in dirNames:
            myNames.append( os.path.join( root, aDir ) )

    return myNames

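#Checks the command line arguments, the IS_DOWNLOAD_PICTURES flag, and the
#first-run flag file, and starts the wget download when any of them ask for it.
#Returns the wget process handle, or None if no download was started.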
def checkArgs(args):
    wgetPID = None

    #for people not using a terminal; they can set the flag.
    if IS_DOWNLOAD_PICTURES:
        wgetPID = retrieveImages()

    if len(args) > 1:
        downloadArgs = ["--download", "-download", "download", "d", "-d", "--d"]
        if not set(downloadArgs).isdisjoint(args):
            wgetPID = retrieveImages()

    #for the first time user
    if not isDownloadedFlagFileSet():
        wgetPID = retrieveImages()

    return wgetPID

if __name__ == '__main__':
    import sys
    sys.exit(main(sys.argv))

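#Example invocation (assumes wget is available on the PATH):
# ~ python3 load-dataset.py --download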