api.mrnn.php 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. <?php
  /**
   * MRNN: minimal single-neuron "network" with one weight and no bias.
   *
   * For every training sample a fresh neuron adjusts its weight until the
   * absolute training error is within $smoothing; learnDataSet() then stores
   * the average of those per-sample weights as this instance's weight.
   */
  class MRNN {

      /**
       * Neuron weight (0.01 on a fresh instance)
       *
       * @var float
       */
      protected $weight = 0.01;

      /**
       * Error of the last training step (expected result minus actual output)
       *
       * @var float
       */
      protected $lastError = 1;

      /**
       * Learning rate that scales every weight correction; also used by learn()
       * as the convergence tolerance for the absolute training error
       *
       * @var float
       */
      protected $smoothing = 0.0001;

      /**
       * Output computed on the last training step
       *
       * @var float
       */
      protected $actualResult = 0.01;

      /**
       * Weight correction applied on the last training step
       *
       * @var float
       */
      protected $correction = 0;

      /**
       * Number of training steps performed by the current learn() call
       *
       * @var int
       */
      protected $epoch = 0;

      /**
       * Training statistics. On a neuron trained by learn() this is epoch => error;
       * on the instance that ran learnDataSet() every element is one trained
       * neuron's epoch => error array.
       *
       * @var array
       */
      protected $trainStats = array();

      /**
       * Training error is recorded every this many epochs (epoch 0 included)
       *
       * @var int
       */
      protected $statEvery = 5000;

      /**
       * Whether to emit debug messages during training
       *
       * @var bool
       */
      protected $debug = false;

      /**
       * Activation function type: 'def' or 'sigmoid'
       *
       * @var string
       */
      protected $activationFunction = 'def';

      /**
       * Creates a network instance with the given activation function.
       *
       * @param string $activationFunction activation function: 'def' or 'sigmoid'
       *
       * @throws Exception EX_WRONG_ACTFUNCTION when the type is not supported
       */
      public function __construct($activationFunction = 'def') {
          $this->setActivationFunc($activationFunction);
      }

      /**
       * Sets neuron instance weight
       *
       * @param float $weight
       *
       * @return void
       */
      public function setWeight($weight) {
          $this->weight = $weight;
      }

      /**
       * Sets network instance activation function type
       *
       * @param string $type 'def' or 'sigmoid'
       *
       * @throws Exception EX_WRONG_ACTFUNCTION when the type is not supported
       *
       * @return void
       */
      protected function setActivationFunc($type) {
          $supportedTypes = array(
              'def' => 'def',
              'sigmoid' => 'sigmoid'
          );
          if (isset($supportedTypes[$type])) {
              $this->activationFunction = $type;
          } else {
              throw new Exception('EX_WRONG_ACTFUNCTION');
          }
      }

      /**
       * Returns data output processed by trained neuron (forward): input * weight.
       * The activation function is not applied here.
       *
       * @param float $input
       *
       * @return float
       */
      public function processInputData($input) {
          $result = $input * $this->weight;
          return ($result);
      }

      /**
       * Returns data input processed by trained neuron (backward): output / weight.
       * The weight must be non-zero.
       *
       * @param float $output
       *
       * @return float
       */
      public function restoreInputData($output) {
          $result = $output / $this->weight;
          return ($result);
      }

      /**
       * Logistic sigmoid: 1 / (1 + e^-value)
       *
       * @param float $value
       *
       * @return float value in [0, 1]
       */
      protected function sigmoid($value) {
          return (1 / (1 + exp(-$value)));
      }

      /**
       * Inverse of the logistic sigmoid (logit): ln(value / (1 - value)).
       * Only defined for 0 < value < 1; a value that has saturated to exactly
       * 1.0 makes the division by (1 - value) divide by zero.
       *
       * @param float $value
       *
       * @return float
       */
      protected function unsigmoid($value) {
          return (log($value / (1 - $value)));
      }

      /**
       * Performs one training step: computes the output and error for the sample
       * and moves the weight by (error / output) * smoothing.
       *
       * NOTE(review): in 'sigmoid' mode the divisor is a sigmoid output, which is
       * never negative, so the correction's sign follows the error's sign only;
       * for negative inputs this appears to move the weight away from the
       * target - confirm before training on negative inputs.
       *
       * @param float $input
       * @param float $expectedResult
       *
       * @return void
       */
      protected function train($input, $expectedResult) {
          switch ($this->activationFunction) {
              case 'def':
                  $this->actualResult = $input * $this->weight;
                  $this->lastError = $expectedResult - $this->actualResult;
                  $this->correction = ($this->lastError / $this->actualResult) * $this->smoothing;
                  $this->weight += $this->correction;
                  break;
              case 'sigmoid':
                  $this->actualResult = $input * $this->weight;
                  $this->actualResult = $this->sigmoid($this->actualResult);
                  $this->lastError = $expectedResult - $this->unsigmoid($this->actualResult);
                  $this->correction = ($this->lastError / $this->actualResult) * $this->smoothing;
                  $this->weight += $this->correction;
                  break;
          }
      }

      /**
       * Trains this neuron on a single sample until |lastError| <= smoothing.
       * The error is recorded in trainStats every $statEvery epochs.
       * There is no epoch limit, so a sample that never converges never returns.
       *
       * @param float $input
       * @param float $expectedResult
       *
       * @return bool
       */
      protected function learn($input, $expectedResult) {
          $this->epoch = 0;
          // do/while: always perform at least one step, so a lastError left over
          // from a previous call cannot make this return without training
          do {
              $this->train($input, $expectedResult);
              //log train stats
              if (($this->epoch % $this->statEvery) == 0) {
                  $this->trainStats[$this->epoch] = $this->lastError;
              }
              $this->epoch++;
          } while (abs($this->lastError) > $this->smoothing);
          return (true);
      }

      /**
       * Performs training of the network on a whole data set.
       *
       * A fresh neuron is trained for each sample and this instance's weight is
       * set to the average of the resulting weights. Samples whose input is zero
       * are skipped, because 0 * weight never changes and training on them
       * cannot converge.
       *
       * Inputs are array keys, so PHP truncates float keys to int; use string
       * keys (e.g. '0.5') to keep fractional inputs.
       *
       * @param array $dataSet inputs data array like array(inputValue=>estimatedValue)
       * @param bool $accel start each neuron from the previously trained weight
       *
       * @return bool true if the weight was updated from at least one sample,
       *              false for a non-array, empty or all-zero-input data set
       */
      public function learnDataSet($dataSet, $accel = false) {
          if (!is_array($dataSet) OR empty($dataSet)) {
              return (false);
          }

          $totalWeight = 0;
          $trainedCount = 0;
          $prevWeight = $this->weight;
          $networkName = get_class($this);
          foreach ($dataSet as $input => $expectedResult) {
              // a zero input carries no information about the weight and would
              // make train() divide by a zero output (or loop forever in sigmoid mode)
              if ((float) $input == 0.0) {
                  continue;
              }
              $neuron = new $networkName($this->activationFunction);
              //optional learning acceleration via next weight correction
              if ($accel) {
                  $neuron->setWeight($prevWeight);
              }
              if ($neuron->learn($input, $expectedResult)) {
                  if ($this->debug) {
                      show_success('Trained weight: ' . $neuron->getWeight() . ' on epoch ' . $neuron->getEpoch());
                  }
                  $totalWeight += $neuron->getWeight();
                  $this->trainStats[] = $neuron->getTrainStats();
                  $prevWeight = $neuron->getWeight();
                  $trainedCount++;
              }
          }

          if ($trainedCount == 0) {
              return (false);
          }
          $this->weight = $totalWeight / $trainedCount; //learning complete
          return (true);
      }

      /**
       * Returns current network instance training stats
       *
       * @return array
       */
      public function getTrainStats() {
          return ($this->trainStats);
      }

      /**
       * Returns current neuron weight
       *
       * @return float
       */
      public function getWeight() {
          return ($this->weight);
      }

      /**
       * Returns the error of the last training step
       *
       * @return float
       */
      protected function getLastError() {
          return ($this->lastError);
      }

      /**
       * Returns the number of training steps performed by the last learn() call
       *
       * @return int
       */
      protected function getEpoch() {
          return ($this->epoch);
      }

      /**
       * Sets debug state of learning progress
       *
       * @param bool $debugState
       *
       * @return void
       */
      public function setDebug($debugState = false) {
          $this->debug = $debugState;
      }

      /**
       * Renders a line chart of training progress (epoch vs error).
       *
       * @param array $trainStats per-neuron stats, each an epoch => error array,
       *                          e.g. getTrainStats() after learnDataSet()
       *
       * @return string chart markup, or an empty string when there are no stats
       */
      public function visualizeTrain($trainStats) {
          $result = '';
          $chartData = array(0 => array(__('Epoch'), __('Error')));
          if (!empty($trainStats)) {
              foreach ($trainStats as $neuron => $neuronStats) {
                  if (!empty($neuronStats)) {
                      foreach ($neuronStats as $epoch => $error) {
                          $chartData[] = array($epoch, $error);
                      }
                  }
              }
              $result .= wf_gchartsLine($chartData, __('Network training') . ' ' . $this->activationFunction, '100%', '400px', '');
          }
          return ($result);
      }
  }