#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# load-dataset.py
#
# 2022 Stephen Stengel <stephen.stengel@cwu.edu>
#
# Script to load the dataset into file form for easy importing.
#
print("Loading imports...")
import os
import random
import shutil
import subprocess
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage
import tensorflow as tf
from skimage.io import imsave
from skimage.util import img_as_uint
from tqdm import tqdm  # Pretty loading bars

print("Done!")
  24. CLASS_INTERESTING = 0
  25. CLASS_NOT_INTERESTING = 1
  26. CLASS_INTERESTING_STRING = "interesting"
  27. CLASS_NOT_INTERESTING_STRING = "not"
  28. DATASET_DIRECTORY = os.path.normpath("./ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/614s/")
  29. INTERESTING_DIRECTORY = os.path.join(DATASET_DIRECTORY, "interesting")
  30. NOT_INTERESTING_DIRECTORY = os.path.join(DATASET_DIRECTORY, "not interesting")
  31. DATASET_COPY_FOLDER = os.path.normpath("./tmpdata/")
  32. DATASET_COPY_FOLDER_INT = os.path.join(DATASET_COPY_FOLDER, CLASS_INTERESTING_STRING)
  33. DATASET_COPY_FOLDER_NOT = os.path.join(DATASET_COPY_FOLDER, CLASS_NOT_INTERESTING_STRING)
  34. DATASET_PNG_FOLDER = os.path.normpath("./datasets-as-png/")
  35. DATASET_PNG_FOLDER_TRAIN = os.path.join(DATASET_PNG_FOLDER, "train")
  36. DATASET_PNG_FOLDER_TRAIN_INT = os.path.join(DATASET_PNG_FOLDER_TRAIN, CLASS_INTERESTING_STRING)
  37. DATASET_PNG_FOLDER_TRAIN_NOT = os.path.join(DATASET_PNG_FOLDER_TRAIN, CLASS_NOT_INTERESTING_STRING)
  38. DATASET_PNG_FOLDER_VAL = os.path.join(DATASET_PNG_FOLDER, "val")
  39. DATASET_PNG_FOLDER_VAL_INT = os.path.join(DATASET_PNG_FOLDER_VAL, CLASS_INTERESTING_STRING)
  40. DATASET_PNG_FOLDER_VAL_NOT = os.path.join(DATASET_PNG_FOLDER_VAL, CLASS_NOT_INTERESTING_STRING)
  41. DATASET_PNG_FOLDER_TEST = os.path.join(DATASET_PNG_FOLDER, "test")
  42. DATASET_PNG_FOLDER_TEST_INT = os.path.join(DATASET_PNG_FOLDER_TEST, CLASS_INTERESTING_STRING)
  43. DATASET_PNG_FOLDER_TEST_NOT = os.path.join(DATASET_PNG_FOLDER_TEST, CLASS_NOT_INTERESTING_STRING)
  44. DATASET_SAVE_DIR = os.path.normpath("./dataset/")
  45. TRAIN_SAVE_DIRECTORY = os.path.join(DATASET_SAVE_DIR, "train")
  46. VAL_SAVE_DIRECTORY = os.path.join(DATASET_SAVE_DIR, "val")
  47. TEST_SAVE_DIRECTORY = os.path.join(DATASET_SAVE_DIR, "test")
  48. ALL_FOLDERS_LIST = [
  49. DATASET_DIRECTORY,
  50. INTERESTING_DIRECTORY,
  51. NOT_INTERESTING_DIRECTORY,
  52. DATASET_COPY_FOLDER,
  53. DATASET_COPY_FOLDER_INT,
  54. DATASET_COPY_FOLDER_NOT,
  55. DATASET_PNG_FOLDER,
  56. DATASET_PNG_FOLDER_TRAIN,
  57. DATASET_PNG_FOLDER_TRAIN_INT,
  58. DATASET_PNG_FOLDER_TRAIN_NOT,
  59. DATASET_PNG_FOLDER_VAL,
  60. DATASET_PNG_FOLDER_VAL_INT,
  61. DATASET_PNG_FOLDER_VAL_NOT,
  62. DATASET_PNG_FOLDER_TEST,
  63. DATASET_PNG_FOLDER_TEST_INT,
  64. DATASET_PNG_FOLDER_TEST_NOT,
  65. DATASET_SAVE_DIR,
  66. TRAIN_SAVE_DIRECTORY,
  67. VAL_SAVE_DIRECTORY,
  68. TEST_SAVE_DIRECTORY
  69. ]
  70. HIDDEN_DOWNLOAD_FLAG_FILE = ".isnotfirstdownload"
  71. CLASS_NAMES_LIST_INT = [CLASS_INTERESTING, CLASS_NOT_INTERESTING]
  72. CLASS_NAMES_LIST_STR = [CLASS_INTERESTING_STRING, CLASS_NOT_INTERESTING_STRING]
  73. TEST_PRINTING = False
  74. IS_SAVE_THE_DATASETS = True
  75. IS_SAVE_THE_PNGS = True
  76. IS_DOWNLOAD_PICTURES = False
  77. def main(args):
  78. random.seed()
  79. rInt = random.randint(0, 2**32 - 1) ##must be between 0 and 2**32 - 1
  80. print("Global level random seed: " + str(rInt))
  81. tf.keras.utils.set_random_seed(rInt) #sets random, numpy, and tensorflow seeds.
  82. #Tensorflow has a global random seed and a operation level seeds: https://www.tensorflow.org/api_docs/python/tf/random/set_seed
  83. print("Hello! This is the Animal Crossing Dataset Loader!")
  84. if not areAllProgramsInstalled():
  85. print("Not all programs installed.")
  86. exit(-2)
  87. makeDirectories(ALL_FOLDERS_LIST)
  88. checkArgs(args)
  89. print("DATASET_DIRECTORY: " + str(DATASET_DIRECTORY))
  90. print("Creating file structure...")
  91. createFileStructure(INTERESTING_DIRECTORY, DATASET_COPY_FOLDER_INT)
  92. createFileStructure(NOT_INTERESTING_DIRECTORY, DATASET_COPY_FOLDER_NOT)
  93. print("Done!")
  94. interestingFNames = getListOfAnimalPicsInOneClass(DATASET_COPY_FOLDER_INT)
  95. notInterestingFNames = getListOfAnimalPicsInOneClass(DATASET_COPY_FOLDER_NOT)
  96. #These WILL change later
  97. # ~ img_width = 400
  98. # ~ img_height = 300
  99. img_width = 100
  100. img_height = 100
  101. batch_size = 32
  102. percentageTrain = 0.6
  103. percentageTestToVal = 0.75
  104. print("creating the datasets...")
  105. train_ds, val_ds, test_ds = createAnimalsDataset(
  106. DATASET_COPY_FOLDER, img_height, img_width, batch_size, percentageTrain, percentageTestToVal)
  107. print("Done!")
  108. print("Saving datasets...")
  109. if IS_SAVE_THE_DATASETS:
  110. saveDatasets(
  111. train_ds, TRAIN_SAVE_DIRECTORY,
  112. val_ds, VAL_SAVE_DIRECTORY,
  113. test_ds, TEST_SAVE_DIRECTORY)
  114. print("Done!")
  115. else:
  116. print("Saving disabled for now!")
  117. if IS_SAVE_THE_PNGS:
  118. print("Saving the datasets as image files...")
  119. saveDatasetAsPNG(train_ds, DATASET_PNG_FOLDER_TRAIN)
  120. saveDatasetAsPNG(val_ds, DATASET_PNG_FOLDER_VAL)
  121. saveDatasetAsPNG(test_ds, DATASET_PNG_FOLDER_TEST)
  122. else:
  123. print("PNG saving disabled for now!")
  124. print("Deleting the temporary image folder...")
  125. shutil.rmtree(DATASET_COPY_FOLDER)
  126. if sys.platform.startswith("linux"):
  127. os.sync()
  128. print("Done!")
  129. return 0
  130. # Creates the necessary directories.
  131. def makeDirectories(listOfFoldersToCreate):
  132. #Clear out old files -- Justin Case
  133. if os.path.isdir(DATASET_SAVE_DIR):
  134. shutil.rmtree(DATASET_SAVE_DIR, ignore_errors = True)
  135. if os.path.isdir(DATASET_PNG_FOLDER):
  136. shutil.rmtree(DATASET_PNG_FOLDER, ignore_errors = True)
  137. if os.path.isdir(DATASET_COPY_FOLDER):
  138. shutil.rmtree(DATASET_COPY_FOLDER, ignore_errors = True)
  139. for folder in listOfFoldersToCreate:
  140. if not os.path.isdir(folder):
  141. os.makedirs(folder)
  142. # Retrieves the images if they're not here
  143. def retrieveImages():
  144. print("Retrieving images...")
  145. wgetString = "wget -e robots=off -r -np --mirror https://ftp.wsdot.wa.gov/public/I90Snoq/Biology/thermal/"
  146. if sys.platform.startswith("win"):
  147. os.system("wsl " + wgetString)
  148. elif sys.platform.startswith("linux"):
  149. os.system(wgetString)
  150. else:
  151. print("MASSIVE ERROR LOL!")
  152. exit(-4)
  153. print("Done!")
  154. #Checks if a flag file is in place to determine if the dataset should download from the ftp server.
  155. def isDownloadedFlagFileSet():
  156. if not os.path.isfile(HIDDEN_DOWNLOAD_FLAG_FILE):
  157. Path(HIDDEN_DOWNLOAD_FLAG_FILE).touch(exist_ok=True)
  158. return False
  159. return True
  160. #Takes some images from the validation set and sets the aside for the test set.
  161. def createTestSet(val_ds, percentageTestToVal):
  162. length = np.asarray(val_ds.cardinality())
  163. numForTest = int(length * percentageTestToVal)
  164. test_dataset = val_ds.take(numForTest)
  165. val_ds = val_ds.skip(numForTest)
  166. return val_ds, test_dataset
  167. def saveDatasets(train_ds, trainDir, val_ds, valDir, test_ds, testDir):
  168. tf.data.experimental.save(train_ds, trainDir)
  169. tf.data.experimental.save(val_ds, valDir)
  170. tf.data.experimental.save(test_ds, testDir)
  171. #The batching makes them get stuck together in batches. Right now that's 32 images.
  172. #So whenever you take one from the set, you get a batch of 32 images.
  173. # percentageTrain is a decimal from 0 to 1 of the percent data that should be for train
  174. # percentageTestToVal is a number from 0 to 1 of the percentage of the non-train data for use as test
  175. def createAnimalsDataset(baseDirectory, img_height, img_width, batch_size, percentageTrain, percentageTestToVal):
  176. valSplit = 1 - percentageTrain
  177. splitSeed = random.randint(0, 2**32 - 1)
  178. print("Operation level random seed for image_dataset_from_directory(): " + str(splitSeed))
  179. train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  180. baseDirectory,
  181. labels = "inferred",
  182. label_mode = "int",
  183. class_names = CLASS_NAMES_LIST_STR, #must match directory names
  184. color_mode = "grayscale",
  185. validation_split = valSplit,
  186. subset="training",
  187. seed = splitSeed,
  188. image_size=(img_height, img_width),
  189. batch_size=batch_size)
  190. val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  191. baseDirectory,
  192. labels = "inferred",
  193. label_mode = "int",
  194. class_names = CLASS_NAMES_LIST_STR, #must match directory names
  195. color_mode = "grayscale",
  196. validation_split = valSplit,
  197. subset="validation",
  198. seed = splitSeed,
  199. image_size=(img_height, img_width),
  200. batch_size=batch_size)
  201. val_ds, test_ds = createTestSet(val_ds, percentageTestToVal)
  202. AUTOTUNE = tf.data.AUTOTUNE
  203. normalization_layer = tf.keras.layers.Rescaling(1./255) #for newer versions of tensorflow
  204. # ~ normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255) #for old versions
  205. train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
  206. val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
  207. test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
  208. flippyBoy = tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal")
  209. train_ds = train_ds.map(lambda x, y: (flippyBoy(x), y), num_parallel_calls=AUTOTUNE)
  210. val_ds = val_ds.map(lambda x, y: (flippyBoy(x), y), num_parallel_calls=AUTOTUNE)
  211. myRotate = tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
  212. train_ds = train_ds.map(lambda x, y: (myRotate(x), y), num_parallel_calls=AUTOTUNE)
  213. val_ds = val_ds.map(lambda x, y: (myRotate(x), y), num_parallel_calls=AUTOTUNE)
  214. train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
  215. val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
  216. test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)
  217. if TEST_PRINTING:
  218. print("Showing some unaltered images from the testing set...")
  219. printSample(test_ds)
  220. if TEST_PRINTING:
  221. print("Showing some augmented images from the training set...")
  222. printSample(train_ds)
  223. return train_ds, val_ds, test_ds
  224. # Prints first nine images from the first batch of the dataset.
  225. # It's random as long as you shuffle the dataset! ;)
  226. def printSample(in_ds):
  227. plt.figure(figsize=(10, 10))
  228. for img, label in in_ds.take(1):
  229. for i in tqdm(range(9)):
  230. ax = plt.subplot(3, 3, i + 1)
  231. myImg = np.asarray(img)
  232. plt.imshow(np.asarray(myImg[i]), cmap="gray")
  233. plt.title( CLASS_NAMES_LIST_STR[ np.asarray(label[i]) ] )
  234. plt.axis("off")
  235. plt.show()
  236. #save all images from dataset to file as png
  237. def saveDatasetAsPNG(in_ds, saveFolder):
  238. i = 0
  239. for batch in tqdm(in_ds):
  240. imgArr = np.asarray(batch[0])
  241. labelArr = np.asarray(batch[1])
  242. for j in range(len(imgArr)):
  243. thisImg = imgArr[j]
  244. thisImg = img_as_uint(thisImg)
  245. thisLabel = labelArr[j]
  246. filenamestring = os.path.join(saveFolder, CLASS_NAMES_LIST_STR[thisLabel], str(i) + ".png")
  247. imsave(filenamestring, thisImg)
  248. i = i + 1
  249. def createFileStructure(baseDirSource, destination):
  250. recursivelyCopyAllFilesInFolderToOneDestinationFolder(baseDirSource, destination)
  251. def recursivelyCopyAllFilesInFolderToOneDestinationFolder(baseDirSource, destination):
  252. print("Copying files to " + str(destination))
  253. cpyFiles = getListOfFilenames(baseDirSource)
  254. for thisName in tqdm(cpyFiles):
  255. try:
  256. shutil.copy(thisName, destination)
  257. except:
  258. print("copy skipping: " + str(thisName))
  259. def getListOfAnimalPicsInOneClass(classDir):
  260. dirNames = getListOfDirNames(classDir)
  261. picNames = []
  262. for dName in dirNames:
  263. picNames.extend( getCuratedListOfFileNames(dName) )
  264. return picNames
  265. def getCuratedListOfFileNames(directoryName):
  266. thisNames = getListOfFilenames(directoryName)
  267. thisNames = keepOnlyJPG(thisNames)
  268. return thisNames
  269. def keepOnlyJPG(inList):
  270. for thingy in inList:
  271. pathParts = os.path.splitext(thingy)
  272. if pathParts[-1].lower() != ".jpg" and pathParts[-1].lower() != ".jpeg":
  273. print("excluding non-jpg!: " + str(thingy))
  274. inList.remove(thingy)
  275. return inList
  276. #Returns a list of filenames from the input directory
  277. def getListOfFilenames(baseDirectory):
  278. myNames = []
  279. for (root, dirNames, fileNames) in os.walk(baseDirectory):
  280. for aFile in fileNames:
  281. myNames.append( os.path.join( root, aFile ) )
  282. return myNames
  283. #Returns a list of dirnames from the base
  284. def getListOfDirNames(baseDirectory):
  285. myNames = []
  286. for (root, dirNames, fileNames) in os.walk(baseDirectory):
  287. for aDir in dirNames:
  288. myNames.append( os.path.join( root, aDir ) )
  289. return myNames
  290. def checkArgs(args):
  291. shouldIRetrieveImages = False
  292. #for people not using a terminal; they can set the flag.
  293. if IS_DOWNLOAD_PICTURES:
  294. shouldIRetrieveImages = True
  295. if len(args) > 1:
  296. downloadArgs = ["--download", "-download", "download", "d", "-d", "--d"]
  297. if not set(downloadArgs).isdisjoint(args):
  298. shouldIRetrieveImages = True
  299. #for the first time user
  300. if not isDownloadedFlagFileSet():
  301. shouldIRetrieveImages = True
  302. if shouldIRetrieveImages:
  303. retrieveImages()
  304. ## Check if all the required programs are installed
  305. def areAllProgramsInstalled():
  306. if sys.platform.startswith("win"):
  307. if shutil.which("wsl") is not None:
  308. ## call a subprocess to run the "which wget" command inside wsl
  309. isWgetInstalled = subprocess.check_call(["wsl", "which", "wget"]) == 0
  310. if isWgetInstalled:
  311. return True
  312. else:
  313. print("Missing wget")
  314. return False
  315. else:
  316. print("Missing wsl")
  317. elif sys.platform.startswith("linux"):
  318. if (shutil.which("wget") is not None):
  319. return True
  320. else:
  321. print("Missing wget")
  322. return False
  323. else:
  324. print("Unsupportd operating system! Halting!")
  325. print("This os: " + str(sys.platform))
  326. return False
  327. return False
  328. if __name__ == '__main__':
  329. import sys
  330. sys.exit(main(sys.argv))