train-model.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # train-model.py
  5. #
  6. # Copyright 2022 Stephen Stengel <stephen.stengel@cwu.edu>
  7. #
  8. print("Loading imports...")
  9. import os
  10. import tensorflow as tf
  11. import matplotlib.pyplot as plt
  12. import numpy as np
  13. import shutil
  14. import time
  15. from tqdm import tqdm
  16. from models import createHarlowModel, simpleModel
  17. from keras import callbacks
  18. from keras import backend
  19. print("Done!")
  20. LOADER_DIRECTORY = os.path.normpath("../animal-crossing-loader/")
  21. TRAIN_DIRECTORY = os.path.join(LOADER_DIRECTORY, "dataset", "train")
  22. VAL_DIRECTORY = os.path.join(LOADER_DIRECTORY, "dataset", "val")
  23. TEST_DIRECTORY = os.path.join(LOADER_DIRECTORY, "dataset", "test")
  24. CLASS_INTERESTING = 0
  25. CLASS_NOT_INTERESTING = 1
  26. CLASS_INTERESTING_STRING = "interesting"
  27. CLASS_NOT_INTERESTING_STRING = "not"
  28. CLASS_NAMES_LIST_INT = [CLASS_INTERESTING, CLASS_NOT_INTERESTING]
  29. CLASS_NAMES_LIST_STR = [CLASS_INTERESTING_STRING, CLASS_NOT_INTERESTING_STRING]
  30. TEST_PRINTING = True
  31. IMG_WIDTH = 100
  32. IMG_HEIGHT = 100
  33. # ~ IMG_WIDTH = 200
  34. # ~ IMG_HEIGHT = 150
  35. # ~ IMG_WIDTH = 400
  36. # ~ IMG_HEIGHT = 300
  37. IMG_CHANNELS = 1
  38. IMG_SHAPE_TUPPLE = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
  39. def main(args):
  40. listOfFoldersToDELETE = []
  41. deleteDirectories(listOfFoldersToDELETE)
  42. #base folder for this run
  43. ts = time.localtime()
  44. timeStr = "./%d-%d-%d-%d-%d-%d/" % (ts.tm_year, ts.tm_mon, ts.tm_mday, ts.tm_hour, ts.tm_min, ts.tm_sec)
  45. timeStr = os.path.normpath(timeStr)
  46. # Folders to save model tests
  47. simpleFolder = os.path.join(timeStr, "simple")
  48. harlowFolder = os.path.join(timeStr, "harlow")
  49. modelBaseFolders = [simpleFolder, harlowFolder] #Same order as the modelList below!
  50. makeDirectories(modelBaseFolders)
  51. # train_ds is for training the model.
  52. # val_ds is for validation during training.
  53. # test_ds is a dataset of unmodified images for testing the model after training.
  54. train_ds, val_ds, test_ds = getDatasets(TRAIN_DIRECTORY, VAL_DIRECTORY, TEST_DIRECTORY)
  55. if TEST_PRINTING:
  56. printSample(test_ds)
  57. shape = IMG_SHAPE_TUPPLE
  58. modelList = [simpleModel(shape), createHarlowModel(shape)]
  59. # This for loop can be compartmentalized into helper functions.
  60. # There will be one wrapper function to perform k-folds
  61. # something like performExperiment -> performKfolds -> contents of this for loop.
  62. for i in range(len(modelList)):
  63. thisModel = modelList[i]
  64. thisModel.summary()
  65. thisOutputFolder = modelBaseFolders[i]
  66. print("Training model: " + thisOutputFolder)
  67. thisCheckpointFolder = os.path.join(thisOutputFolder, "checkpoint")
  68. foldersForThisModel = [thisOutputFolder, thisCheckpointFolder]
  69. makeDirectories(foldersForThisModel)
  70. #save copy of source code that created the output
  71. os.system("cp train-model.py " + os.path.join(thisOutputFolder, "train-model.py"))
  72. myHistory = trainModel(thisModel, train_ds, val_ds, thisCheckpointFolder)
  73. print("Creating graphs of training history...")
  74. strAcc, strLoss = saveGraphs(thisModel, myHistory, test_ds, thisOutputFolder)
  75. #workin on this.
  76. stringToPrint = evaluateLabels(test_ds, thisModel, thisOutputFolder)
  77. stringToPrint += "Accuracy and loss according to tensorflow model.evaluate():\n"
  78. stringToPrint += strAcc + "\n"
  79. stringToPrint += strLoss + "\n"
  80. statFileName = os.path.join(thisOutputFolder, "stats.txt")
  81. printStringToFile(statFileName, stringToPrint, "w")
  82. print(stringToPrint)
  83. return 0
  84. # model.predict() makes an array of probabilities that a certian class is correct.
  85. # By saving the scores from the test_ds, we can see which images
  86. # cause false-positives, false-negatives, true-positives, and true-negatives
  87. def evaluateLabels(test_ds, model, outputFolder):
  88. print("Getting predictions of test data...")
  89. testScores = model.predict(test_ds, verbose = True)
  90. actual_test_labels = extractLabels(test_ds)
  91. #Get the list of class predictions from the probability scores.
  92. p_test_labels = getPredictedLabels(testScores)
  93. printLabelStuffToFile(testScores, actual_test_labels, p_test_labels, outputFolder) # debug function
  94. #Calculate TPR, FPR, TNR, FNR
  95. outString = ""
  96. tp_sum = getTPsum(actual_test_labels, p_test_labels)
  97. outString += "truePos: " + str(tp_sum) + "\n"
  98. tn_sum = getTNsum(actual_test_labels, p_test_labels)
  99. outString += "true negative: " + str(tn_sum) + "\n"
  100. fp_sum = getFPsum(actual_test_labels, p_test_labels)
  101. outString += "false pos: " + str(fp_sum) + "\n"
  102. fn_sum = getFNsum(actual_test_labels, p_test_labels)
  103. outString += "false negative: " + str(fn_sum) + "\n"
  104. accuracy = getAcc(tp_sum, tn_sum, fp_sum, fn_sum)
  105. outString += "accuracy: " + str(accuracy) + "\n"
  106. err = getErrRate(tp_sum, tn_sum, fp_sum, fn_sum)
  107. outString += "error rate: " + str(err) + "\n"
  108. tpr = getTPR(tp_sum, fn_sum)
  109. outString += "True Positive Rate: " + str(tpr) + "\n"
  110. tNr = getTNR(tn_sum, fp_sum)
  111. outString += "True Negative Rate: " + str(tNr) + "\n"
  112. precision = getPrecision(tp_sum, fp_sum)
  113. outString += "Precision: " + str(precision) + "\n"
  114. #Save the false positive, false negative images into folders.
  115. #Make a pretty chart of these images?
  116. return outString
  117. def getAcc(tp, tn, fp, fn):
  118. top = tp + tn
  119. bottom = tp + fp + tn + fn
  120. return top / bottom
  121. def getErrRate(tp, tn, fp, fn):
  122. return 1 - getAcc(tp, tn, fp, fn)
  123. # Also known as Sensitivity, recall, and hit rate.
  124. def getTPR(tp, fn):
  125. return tp / (tp + fn)
  126. # Also known as Specificity and selectivity
  127. def getTNR(tn, fp):
  128. return tn / (tn + fp)
  129. # Also known as positive predictive value
  130. def getPrecision(truePos, falsePos):
  131. return truePos / (truePos + falsePos)
  132. # have to think how to do the mask to go very fast.
  133. # i'll just do a loop for now
  134. # I think a lambda function thing would work.
  135. def getTPsum(actual_test_labels, p_test_labels):
  136. sumList = []
  137. for i in range(len(actual_test_labels)):
  138. if (actual_test_labels[i] == CLASS_INTERESTING) and (actual_test_labels[i] == p_test_labels[i]):
  139. sumList.append(1)
  140. else:
  141. sumList.append(0)
  142. sumArr = np.asarray(sumList)
  143. return np.asarray(backend.sum(sumArr))
  144. def getTNsum(actual_test_labels, p_test_labels):
  145. sumList = []
  146. for i in range(len(actual_test_labels)):
  147. if (actual_test_labels[i] == CLASS_NOT_INTERESTING) and (actual_test_labels[i] == p_test_labels[i]):
  148. sumList.append(1)
  149. else:
  150. sumList.append(0)
  151. sumArr = np.asarray(sumList)
  152. return np.asarray(backend.sum(sumArr))
  153. def getFPsum(actual_test_labels, p_test_labels):
  154. sumList = []
  155. for i in range(len(actual_test_labels)):
  156. if (actual_test_labels[i] == CLASS_NOT_INTERESTING) and (actual_test_labels[i] != p_test_labels[i]):
  157. sumList.append(1)
  158. else:
  159. sumList.append(0)
  160. sumArr = np.asarray(sumList)
  161. return np.asarray(backend.sum(sumArr))
  162. def getFNsum(actual_test_labels, p_test_labels):
  163. sumList = []
  164. for i in range(len(actual_test_labels)):
  165. if (actual_test_labels[i] == CLASS_INTERESTING) and (actual_test_labels[i] != p_test_labels[i]):
  166. sumList.append(1)
  167. else:
  168. sumList.append(0)
  169. sumArr = np.asarray(sumList)
  170. return np.asarray(backend.sum(sumArr))
  171. # Creates the necessary directories.
  172. def makeDirectories(listOfFoldersToCreate):
  173. for folder in listOfFoldersToCreate:
  174. if not os.path.isdir(folder):
  175. os.makedirs(folder)
  176. def deleteDirectories(listDirsToDelete):
  177. for folder in listDirsToDelete:
  178. if os.path.isdir(folder):
  179. shutil.rmtree(folder, ignore_errors = True)
  180. # add checkpointer, earlystopper?
  181. def trainModel(model, train_ds, val_ds, checkpointFolder):
  182. checkpointer = callbacks.ModelCheckpoint(
  183. filepath = checkpointFolder,
  184. monitor = "accuracy",
  185. save_best_only = True,
  186. mode = "max")
  187. earlyStopper = callbacks.EarlyStopping(monitor="accuracy", patience = 10)
  188. callbacks_list = [earlyStopper, checkpointer]
  189. return model.fit(
  190. train_ds,
  191. # ~ steps_per_epoch = 1, #to shorten training for testing purposes. I got no gpu qq.
  192. callbacks = callbacks_list,
  193. epochs = 100,
  194. validation_data = val_ds)
  195. def saveGraphs(model, myHistory, test_ds, outputFolder):
  196. evalLoss, evalAccuracy = model.evaluate(test_ds)
  197. plt.clf()
  198. accuracy = myHistory.history['accuracy']
  199. val_accuracy = myHistory.history["val_accuracy"]
  200. epochs = range(1, len(accuracy) + 1)
  201. accCap = round(evalAccuracy, 4)
  202. captionTextAcc = "Accuracy on test data: {}".format(accCap)
  203. plt.figtext(0.5, 0.01, captionTextAcc, wrap=True, horizontalalignment='center', fontsize=12)
  204. plt.plot(epochs, accuracy, "o", label="Training accuracy")
  205. plt.plot(epochs, val_accuracy, "^", label="Validation accuracy")
  206. plt.title("Model Accuracy vs Epochs")
  207. plt.ylabel("accuracy")
  208. plt.xlabel("epoch")
  209. plt.legend()
  210. plt.savefig(os.path.join(outputFolder, "trainvalacc.png"))
  211. plt.clf()
  212. loss = myHistory.history["loss"]
  213. val_loss = myHistory.history["val_loss"]
  214. lossCap = round(evalLoss, 4)
  215. captionTextLoss = "Loss on test data: {}".format(lossCap)
  216. plt.figtext(0.5, 0.01, captionTextLoss, wrap=True, horizontalalignment='center', fontsize=12)
  217. plt.plot(epochs, loss, "o", label="Training loss")
  218. plt.plot(epochs, val_loss, "^", label="Validation loss")
  219. plt.title("Training and validation loss vs Epochs")
  220. plt.ylabel("loss")
  221. plt.xlabel("epoch")
  222. plt.legend()
  223. plt.savefig(os.path.join(outputFolder, "trainvalloss.png"))
  224. plt.clf()
  225. return captionTextAcc, captionTextLoss
  226. def getDatasets(trainDir, valDir, testDir):
  227. train = tf.data.experimental.load(trainDir)
  228. val = tf.data.experimental.load(valDir)
  229. test = tf.data.experimental.load(testDir)
  230. return train, val, test
  231. # Prints first nine images from the first batch of the dataset.
  232. # It's random as long as you shuffle the dataset! ;)
  233. def printSample(in_ds):
  234. plt.figure(figsize=(10, 10))
  235. for img, label in in_ds.take(1):
  236. # ~ for i in tqdm.tqdm(range(9)):
  237. for i in tqdm(range(9)):
  238. ax = plt.subplot(3, 3, i + 1)
  239. myImg = np.asarray(img)
  240. plt.imshow(np.asarray(myImg[i]), cmap="gray")
  241. plt.title( CLASS_NAMES_LIST_STR[ np.asarray(label[i]) ] )
  242. plt.axis("off")
  243. plt.show()
  244. plt.clf()
  245. # Extract the labels from the tensorflow dataset structure.
  246. def extractLabels(in_ds):
  247. print("Trying to get list out of test dataset...")
  248. lablist = []
  249. for batch in tqdm(in_ds):
  250. lablist.extend( np.asarray(batch[1]) )
  251. return np.asarray(lablist)
  252. def printStringToFile(fileName, textString, openMode):
  253. with open(fileName, openMode) as myFile:
  254. for character in textString:
  255. myFile.write(character)
  256. def printLabelStuffToFile(predictedScores, originalLabels, predictedLabels, outputFolder):
  257. with open(os.path.join(outputFolder, "predictionlists.txt"), "w") as outFile:
  258. for i in range(len(predictedScores)):
  259. thisScores = predictedScores[i]
  260. intScore = str(round(thisScores[CLASS_INTERESTING], 4))
  261. notScore = str(round(thisScores[CLASS_NOT_INTERESTING], 4))
  262. thisString = \
  263. "predicted score int,not: [" + intScore + ", " + notScore + "]" \
  264. + "\tactual label " + str(originalLabels[i]) \
  265. + "\tpredicted label" + str(predictedLabels[i]) + "\n"
  266. outFile.write(thisString)
  267. def getPredictedLabels(testScores):
  268. outList = []
  269. for score in testScores:
  270. if score[CLASS_INTERESTING] >= score[CLASS_NOT_INTERESTING]:
  271. outList.append(CLASS_INTERESTING)
  272. else:
  273. outList.append(CLASS_NOT_INTERESTING)
  274. return np.asarray(outList)
  275. if __name__ == '__main__':
  276. import sys
  277. sys.exit(main(sys.argv))