FFNet_PatternList_Categories.cpp 4.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. /* FFNet_PatternList_Categories.cpp
  2. *
  3. * Copyright (C) 1994-2017 David Weenink
  4. *
  5. * This code is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 2 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This code is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. * General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this work. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. /*
  19. djmw 20020712 GPL header.
  20. djmw 20020910 changes.
  21. djmw 20030701 Removed non-GPL minimizations.
  22. djmw 20041118 Added FFNet_PatternList_Categories_getCosts.
  23. */
  24. #include "FFNet_ActivationList_Categories.h"
  25. #include "FFNet_PatternList_Categories.h"
  26. #include "FFNet_PatternList_ActivationList.h"
  27. static void _FFNet_PatternList_Categories_checkDimensions (FFNet me, PatternList p, Categories c) {
  28. Melder_require (my nInputs == p -> nx, U"The PatternList and the FFNet do not match.\nThe number of colums in the PatternList should equal the number of inputs in the FFNet.");
  29. Melder_require (p -> ny == c->size, U"The PatternList and the categories do not match.\nThe number of rows in the PatternList should equal the number of categories.");
  30. Melder_require (_PatternList_checkElements (p), U"All PatternList elements should be in the interval [0, 1].\nYou could use \"Formula...\" to scale the PatternList values first.");
  31. }
  32. double FFNet_PatternList_Categories_getCosts_total (FFNet me, PatternList p, Categories c, int costFunctionType) {
  33. try {
  34. _FFNet_PatternList_Categories_checkDimensions (me, p, c);
  35. autoActivationList activation = FFNet_Categories_to_ActivationList (me, c);
  36. return FFNet_PatternList_ActivationList_getCosts_total (me, p, activation.get(), costFunctionType);
  37. } catch (MelderError) {
  38. return undefined;
  39. }
  40. }
  41. double FFNet_PatternList_Categories_getCosts_average (FFNet me, PatternList p, Categories c, int costFunctionType) {
  42. double costs = FFNet_PatternList_Categories_getCosts_total (me, p, c, costFunctionType);
  43. return ( isundef (costs) ? undefined : costs / p -> ny );
  44. }
  45. void FFNet_PatternList_Categories_learnSD (FFNet me, PatternList p, Categories c, integer maxNumOfEpochs, double tolerance, double learningRate, double momentum, int costFunctionType) {
  46. _FFNet_PatternList_Categories_checkDimensions (me, p, c);
  47. autoActivationList activation = FFNet_Categories_to_ActivationList (me, c);
  48. double min, max;
  49. Matrix_getWindowExtrema (p, 0, 0, 0, 0, & min, & max);
  50. FFNet_PatternList_ActivationList_learnSD (me, p, activation.get(), maxNumOfEpochs, tolerance, learningRate, momentum, costFunctionType);
  51. }
  52. void FFNet_PatternList_Categories_learnSM (FFNet me, PatternList p, Categories c, integer maxNumOfEpochs, double tolerance, int costFunctionType) {
  53. _FFNet_PatternList_Categories_checkDimensions (me, p, c);
  54. autoActivationList activation = FFNet_Categories_to_ActivationList (me, c);
  55. double min, max;
  56. Matrix_getWindowExtrema (p, 0, 0, 0, 0, & min, & max);
  57. FFNet_PatternList_ActivationList_learnSM (me, p, activation.get(), maxNumOfEpochs, tolerance, costFunctionType);
  58. }
  59. autoCategories FFNet_PatternList_to_Categories (FFNet me, PatternList thee, int labeling) {
  60. try {
  61. Melder_require (my outputCategories, U"The FFNet has no output categories.");
  62. Melder_require (my nInputs == thy nx, U"The number of colums in the PatternList (", thy nx, U") should equal the number of inputs in the FFNet (", my nInputs, U").");
  63. Melder_require (_PatternList_checkElements (thee), U"All PatternList elements should be in the interval [0, 1].\nYou could use \"Formula...\" to scale the PatternList values first.");
  64. autoCategories him = Categories_create ();
  65. for (integer k = 1; k <= thy ny; k ++) {
  66. FFNet_propagate (me, thy z [k], nullptr);
  67. integer index = FFNet_getWinningUnit (me, labeling);
  68. autoSimpleString item = Data_copy (my outputCategories->at [index]);
  69. his addItem_move (item.move());
  70. }
  71. return him;
  72. } catch (MelderError) {
  73. Melder_throw (me, U": no Categories created.");
  74. }
  75. }
  76. /* End of file FFNet_PatternList_Categories.cpp */