main.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  #include <cctype>
  #include <cstdlib>
  #include <ctime>
  #include <fstream>
  #include <iostream>
  #include <limits>
  #include <ostream>
  #include <random>
  #include <string>
  #include <vector>

  #include <timer.h>
  #include <betterthanmnist.h>
  #include <neuralnetwork.h>
  using namespace std;
  // Converts a zero-based class index into its ASCII letter code:
  // 0 -> 'A', 1 -> 'B', ... (65 is the ASCII code of 'A').
  char intToChar(unsigned i){
      constexpr unsigned kAsciiUpperA = 65;
      const unsigned code = kAsciiUpperA + i;
      return static_cast<char>(code);
  }
  8. float GetDataAccuracy (BetterThanMnist& data, NeuralNetwork& neuralNetwork, unsigned numTests)
  9. {
  10. unsigned correctItems = 0;
  11. for (unsigned i = 0, c = numTests; i < c; ++i)
  12. {
  13. unsigned label;
  14. vector<float> pixels = data.GetImage(label);
  15. unsigned detectedLabel = neuralNetwork.ForwardPass(pixels);
  16. if (detectedLabel == label)
  17. ++correctItems;
  18. }
  19. return float(correctItems) / float(numTests);
  20. }
  /// Prints an image to stdout as ASCII art, one text line per image row.
  /// A pixel is drawn as '#' when its value is exactly 1.0f and as '.' otherwise.
  /// The exact comparison presumably relies on the data being binarized to 0/1 (confirm).
  /// @param pixels  row-major pixel values. It must hold at least width*height entries;
  ///                otherwise std::out_of_range is thrown by .at().
  /// @param width   pixels per row (defaults to the 28x28 layout used by this program).
  /// @param height  number of rows.
  void ShowImage (const std::vector<float>& pixels, unsigned width = 28, unsigned height = 28)
  {
      const unsigned total = width * height;
      std::string render;
      // One character per pixel plus one leading newline per row.
      render.reserve(total + height);
      for (unsigned i = 0; i < total; ++i){
          if (i % width == 0)
              render += '\n';
          render += (pixels.at(i) == 1.0f) ? '#' : '.';
      }
      std::cout << render << std::endl;
  }
  /// Writes `arr_size` floats from `data` to `fout` as one line, with each value
  /// followed by a single space, and ends the line with std::endl.
  /// Values are written with std::numeric_limits<float>::max_digits10 significant
  /// digits so that reading them back with `>>` reproduces each float exactly.
  /// The stream's default precision of 6 digits silently loses bits of saved weights.
  /// The stream's previous precision is restored afterwards.
  /// Accepts any std::ostream, so existing std::fstream callers are unaffected.
  void writeFloatsInFile(std::ostream& fout, const float* data, const unsigned arr_size){
      const std::streamsize oldPrecision = fout.precision(std::numeric_limits<float>::max_digits10);
      for (unsigned i = 0; i < arr_size; ++i)
          fout << data[i] << ' ';
      fout << std::endl;
      fout.precision(oldPrecision);
  }
  39. void testNetwork(NeuralNetwork& neuralNetwork, BetterThanMnist& picture, unsigned testsNum){
  40. float accuracyTest = GetDataAccuracy(picture, neuralNetwork, testsNum);
  41. cout<<"Test network Accuracy: "<< 100.0f * accuracyTest<<"%"<<endl;
  42. }
  43. void saveResults(NeuralNetwork& neuralNetwork){
  44. fstream fout;
  45. fout.open(c_init_filename);
  46. const float* data = neuralNetwork.GetHiddenLayerBiases();
  47. writeFloatsInFile(fout, data, c_numHiddenNeurons);
  48. data = neuralNetwork.GetOutputLayerBiases();
  49. writeFloatsInFile(fout, data, c_numOutputNeurons);
  50. data = neuralNetwork.GetHiddenLayerWeights();
  51. writeFloatsInFile(fout, data, c_numInputNeurons*c_numHiddenNeurons);
  52. data = neuralNetwork.GetOutputLayerWeights();
  53. writeFloatsInFile(fout, data, c_numHiddenNeurons*c_numOutputNeurons);
  54. cout<<"[+] Results saved."<<endl;
  55. }
  56. void do_honset_test(NeuralNetwork& neuralNetwork, BetterThanMnist& picture, string filename){
  57. vector<float> pic = picture.GetTestImage(filename);
  58. unsigned detectedNum = neuralNetwork.ForwardPass(pic);
  59. ShowImage(pic);
  60. cout<<"Detected number: "<<intToChar(detectedNum);
  61. }
  62. void trainingNetwork(NeuralNetwork& neuralNetwork, BetterThanMnist& picture){
  63. Clock timer("Training Time: ");
  64. cout<<endl<<"[+] Training started--------------------------------"<<endl;
  65. for (unsigned generation = 0; generation < c_trainingGenerations; ++generation){
  66. neuralNetwork.Train(picture);
  67. cout<<"Training generation "<<generation + 1<<" / "<< c_trainingGenerations<<' ';
  68. float accuracy = GetDataAccuracy(picture, neuralNetwork, 20);
  69. cout<<"Test accuracy: "<<100.0f*accuracy<<'%'<<endl;
  70. }
  71. cout<<picture.NumImages()<<" tests passed."<<endl<<endl;
  72. timer.get_info();
  73. cout<<"[+] Training finished-------------------------------"<<endl<<"[+] Now testing."<<endl<<endl;
  74. float accuracyTest = GetDataAccuracy(picture, neuralNetwork, 40);
  75. cout<<"Final Accuracy Test: "<< 100.0f * accuracyTest<<"%"<<endl;
  76. cout<<endl<<"Do you want to save training results? (Y/N)"<<endl;
  77. string cmd; cin>>cmd;
  78. if (cmd == "Y"){
  79. saveResults(neuralNetwork);
  80. }
  81. else{
  82. cout<<"Results aren\'t saved"<<endl;
  83. }
  84. }
  85. void automaticTraining(NeuralNetwork& oldNet, BetterThanMnist& picture){
  86. picture.reopen();
  87. float currentAccuracy = GetDataAccuracy(picture, oldNet, 10000);
  88. cout<<"Old accuracy: "<<currentAccuracy*100<<"%"<<endl<<endl;
  89. picture.reopen();
  90. unsigned short bad_try_count = 0;
  91. Clock timer("Training Time: ");
  92. bool exit = true;
  93. cout<<endl<<"[+] Automatic training started"<<endl;
  94. do{
  95. cout<<"------------------------------"<<endl;
  96. srand(time(0));
  97. unsigned new_batch_size = rand()%10 + 2;
  98. float new_learnig_rate = random_device()()/(random_device().max()/3.2f) +.2f;
  99. NeuralNetwork net(new_batch_size, new_learnig_rate);
  100. net.initialize();
  101. cout<<"New bacth: "<<new_batch_size<<' '<<"New rate: "<<new_learnig_rate<<endl;
  102. for (unsigned generation = 0; generation < c_trainingGenerations; ++generation)
  103. net.Train(picture);
  104. cout<<"[+] tests passed."<<endl;
  105. picture.reopen();
  106. float newAccuracy = GetDataAccuracy(picture, net, 10000);
  107. cout<<"New accuracy: "<<newAccuracy*100<<"%"<<endl;
  108. picture.reopen();
  109. if (newAccuracy > currentAccuracy){
  110. saveResults(net);
  111. currentAccuracy = newAccuracy;
  112. bad_try_count = 0;
  113. }else {
  114. ++bad_try_count;
  115. }
  116. if (bad_try_count == 10)
  117. exit = false;
  118. }while(exit);
  119. timer.get_info();
  120. }
  121. int main (){
  122. NeuralNetwork neuralNetwork;
  123. BetterThanMnist picture;
  124. bool exit = true;
  125. do {
  126. cout<<endl<<"Enter 1 to test network."<<endl;
  127. cout<<"Enter 2 to initialize network."<<endl;
  128. cout<<"Enter 3 to train network."<<endl;
  129. cout<<"Enter 4 to do ONE recognition from your file."<<endl;
  130. cout<<"Enter 5 to reopen test file."<<endl;
  131. cout<<"Enter 6 to save weights."<<endl;
  132. cout<<"Enter 7 to automatic training."<<endl;
  133. cout<<"Enter 0 to exit"<<endl<<endl;
  134. char cmd; cin>>cmd;
  135. if (isdigit(cmd)){
  136. int choice = int(cmd) - 48;
  137. switch (choice) {
  138. case 0:
  139. exit = false;
  140. break;
  141. case 1:
  142. cout<<endl<<"How many tests do you want?"<<endl;
  143. unsigned b; cin>>b;
  144. testNetwork(neuralNetwork, picture, b);
  145. break;
  146. case 2:
  147. neuralNetwork.initialize();
  148. break;
  149. case 3:
  150. trainingNetwork(neuralNetwork, picture);
  151. break;
  152. case 5:
  153. picture.reopen();
  154. break;
  155. case 6:
  156. saveResults(neuralNetwork);
  157. break;
  158. case 7:
  159. automaticTraining(neuralNetwork, picture);
  160. break;
  161. case 4:
  162. string file;cout<<"Enter filename: ";cin>>file;
  163. do_honset_test(neuralNetwork, picture, file);
  164. break;
  165. }
  166. }else {
  167. exit = false;
  168. cout<<"Wrong input!";
  169. }
  170. }while(exit);
  171. return 0;
  172. }