LogisticRegression.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. /* LogisticRegression.cpp
  2. *
  3. * Copyright (C) 2005-2012,2015-2018 Paul Boersma
  4. *
  5. * This code is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 2 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This code is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. * See the GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this work. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. #include "LogisticRegression.h"
  19. #include "../kar/UnicodeData.h"
  20. #include "oo_DESTROY.h"
  21. #include "LogisticRegression_def.h"
  22. #include "oo_COPY.h"
  23. #include "LogisticRegression_def.h"
  24. #include "oo_EQUAL.h"
  25. #include "LogisticRegression_def.h"
  26. #include "oo_CAN_WRITE_AS_ENCODING.h"
  27. #include "LogisticRegression_def.h"
  28. #include "oo_WRITE_TEXT.h"
  29. #include "LogisticRegression_def.h"
  30. #include "oo_WRITE_BINARY.h"
  31. #include "LogisticRegression_def.h"
  32. #include "oo_READ_TEXT.h"
  33. #include "LogisticRegression_def.h"
  34. #include "oo_READ_BINARY.h"
  35. #include "LogisticRegression_def.h"
  36. #include "oo_DESCRIPTION.h"
  37. #include "LogisticRegression_def.h"
/* Register the LogisticRegression class with Praat's Thing type system (parent class: Regression, version 0). */
Thing_implement (LogisticRegression, Regression, 0);
  39. void structLogisticRegression :: v_info () {
  40. LogisticRegression_Parent :: v_info ();
  41. MelderInfo_writeLine (U"Dependent 1: ", our dependent1.get());
  42. MelderInfo_writeLine (U"Dependent 2: ", our dependent2.get());
  43. MelderInfo_writeLine (U"Interpretation:");
  44. MelderInfo_write (U" ln (P(", our dependent2.get(), U")/P(", our dependent1.get(), U")) " UNITEXT_ALMOST_EQUAL_TO U" ", Melder_fixed (intercept, 6));
  45. for (integer ivar = 1; ivar <= parameters.size; ivar ++) {
  46. RegressionParameter parm = parameters.at [ivar];
  47. MelderInfo_write (parm -> value < 0.0 ? U" - " : U" + ", Melder_fixed (fabs (parm -> value), 6), U" * ", parm -> label.get());
  48. }
  49. MelderInfo_writeLine (U"");
  50. MelderInfo_writeLine (U"Log odds ratios:");
  51. for (integer ivar = 1; ivar <= parameters.size; ivar ++) {
  52. RegressionParameter parm = parameters.at [ivar];
  53. MelderInfo_writeLine (U" Log odds ratio of factor ", parm -> label.get(), U": ", Melder_fixed ((parm -> maximum - parm -> minimum) * parm -> value, 6));
  54. }
  55. MelderInfo_writeLine (U"Odds ratios:");
  56. for (integer ivar = 1; ivar <= parameters.size; ivar ++) {
  57. RegressionParameter parm = parameters.at [ivar];
  58. MelderInfo_writeLine (U" Odds ratio of factor ", parm -> label.get(), U": ", exp ((parm -> maximum - parm -> minimum) * parm -> value));
  59. }
  60. }
  61. autoLogisticRegression LogisticRegression_create (conststring32 dependent1, conststring32 dependent2) {
  62. try {
  63. autoLogisticRegression me = Thing_new (LogisticRegression);
  64. Regression_init (me.get());
  65. my dependent1 = Melder_dup (dependent1);
  66. my dependent2 = Melder_dup (dependent2);
  67. return me;
  68. } catch (MelderError) {
  69. Melder_throw (U"LogisticRegression not created.");
  70. }
  71. }
/*
	Fit a two-category logistic regression to the cell counts in Table `me`,
	by Newton-Raphson maximization of the log-likelihood.
	`factors` holds the column numbers of the independent variables;
	`dependent1` and `dependent2` are the column numbers holding the counts
	of the two outcome categories (y0 and y1) in each table row ("cell").
	Throws MelderError if there are no factor columns or if either category has no data.
*/
static autoLogisticRegression _Table_to_LogisticRegression (Table me, constINTVEC factors, integer dependent1, integer dependent2) {
	const integer numberOfFactors = factors.size;
	const integer numberOfParameters = numberOfFactors + 1;   // the factors plus the intercept
	const integer numberOfCells = my rows.size;   // one table row = one cell with two counts
	// NOTE(review): the per-cell counts are read as doubles but accumulated into integers here,
	// so fractional counts would be truncated — confirm the table is expected to hold whole counts.
	integer numberOfY0 = 0, numberOfY1 = 0, numberOfData = 0;
	double logLikelihood = 1e307, previousLogLikelihood = 1e308;   // sentinels: the first iteration can never look converged
	if (numberOfParameters < 1)   // includes intercept
		Melder_throw (U"Not enough columns (has to be more than 1).");
	/*
	 * Divide up the contents of the table into a number of independent variables (x) and two dependent variables (y0 and y1).
	 */
	autoNUMmatrix <double> x (1, numberOfCells, 0, numberOfFactors);   // column 0 is the intercept
	autoNUMvector <double> y0 (1, numberOfCells);
	autoNUMvector <double> y1 (1, numberOfCells);
	autoNUMvector <double> meanX (1, numberOfFactors);    // count-weighted mean of each factor
	autoNUMvector <double> stdevX (1, numberOfFactors);   // count-weighted standard deviation of each factor
	autoNUMmatrix <double> smallMatrix (0, numberOfFactors, 0, numberOfParameters);   // Hessian, with the gradient in the extra last column
	autoLogisticRegression thee = LogisticRegression_create (my columnHeaders [dependent1]. label.get(), my columnHeaders [dependent2]. label.get());
	/* Record each factor's label and observed range on the result object. */
	for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
		double minimum = Table_getMinimum (me, factors [ivar]);
		double maximum = Table_getMaximum (me, factors [ivar]);
		Regression_addParameter (thee.get(), my columnHeaders [factors [ivar]]. label.get(), minimum, maximum, 0.0);
	}
	/* Read the counts and factor values; accumulate totals and count-weighted factor sums. */
	for (integer icell = 1; icell <= numberOfCells; icell ++) {
		y0 [icell] = Table_getNumericValue_Assert (me, icell, dependent1);
		y1 [icell] = Table_getNumericValue_Assert (me, icell, dependent2);
		numberOfY0 += y0 [icell];
		numberOfY1 += y1 [icell];
		numberOfData += y0 [icell] + y1 [icell];
		x [icell] [0] = 1.0;   // intercept
		for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
			x [icell] [ivar] = Table_getNumericValue_Assert (me, icell, factors [ivar]);
			meanX [ivar] += x [icell] [ivar] * (y0 [icell] + y1 [icell]);   // weight by the total count in the cell
		}
	}
	if (numberOfY0 == 0 && numberOfY1 == 0)
		Melder_throw (U"No data in either class. Cannot determine result.");
	if (numberOfY0 == 0)
		Melder_throw (U"No data in class ", my columnHeaders [dependent1]. label.get(), U". Cannot determine result.");
	if (numberOfY1 == 0)
		Melder_throw (U"No data in class ", my columnHeaders [dependent2]. label.get(), U". Cannot determine result.");
	/*
	 * Normalize the data: z-score each factor (count-weighted), for numerical stability
	 * of the iteration. The normalization is undone again after convergence, below.
	 */
	for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
		meanX [ivar] /= numberOfData;
		for (integer icell = 1; icell <= numberOfCells; icell ++) {
			x [icell] [ivar] -= meanX [ivar];
		}
	}
	for (integer icell = 1; icell <= numberOfCells; icell ++) {
		for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
			stdevX [ivar] += x [icell] [ivar] * x [icell] [ivar] * (y0 [icell] + y1 [icell]);
		}
	}
	for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
		stdevX [ivar] = sqrt (stdevX [ivar] / numberOfData);
		// NOTE(review): a constant factor column gives stdevX == 0 and a division by zero here — confirm callers guard against that.
		for (integer icell = 1; icell <= numberOfCells; icell ++) {
			x [icell] [ivar] /= stdevX [ivar];
		}
	}
	/*
	 * Initial state of iteration: the null model.
	 */
	thy intercept = log ((double) numberOfY1 / (double) numberOfY0);   // initial state of intercept: best guess for average log odds
	for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
		RegressionParameter parm = thy parameters.at [ivar];
		parm -> value = 0.0;   // initial state of dependence: none
	}
	integer iteration = 1;
	for (; iteration <= 100; iteration ++) {   // Newton-Raphson steps, capped at 100
		previousLogLikelihood = logLikelihood;
		/* Clear the upper triangle of the Hessian plus the gradient column (the lower triangle is filled by symmetry later). */
		for (integer ivar = 0; ivar <= numberOfFactors; ivar ++) {
			for (integer jvar = ivar; jvar <= numberOfParameters; jvar ++) {
				smallMatrix [ivar] [jvar] = 0.0;
			}
		}
		/*
		 * Compute the current log likelihood.
		 */
		logLikelihood = 0.0;
		for (integer icell = 1; icell <= numberOfCells; icell ++) {
			double fittedLogit = thy intercept, fittedP, fittedQ, fittedLogP, fittedLogQ, fittedPQ, fittedVariance;
			for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
				RegressionParameter parm = thy parameters.at [ivar];
				fittedLogit += parm -> value * x [icell] [ivar];
			}
			/*
			 * Basically we have fittedP = 1.0 / (1.0 + exp (- fittedLogit)),
			 * but that works neither for fittedP values near 0 nor for values near 1.
			 */
			if (fittedLogit > 15.0) {
				/*
				 * For large fittedLogit, fittedLogP = ln (1/(1+exp(-fittedLogit))) = -ln (1+exp(-fittedLogit)) =~ - exp(-fittedLogit)
				 */
				fittedLogP = - exp (- fittedLogit);
				fittedLogQ = - fittedLogit;
				fittedPQ = exp (- fittedLogit);
				fittedP = exp (fittedLogP);
				fittedQ = 1.0 - fittedP;
			} else if (fittedLogit < -15.0) {   // the mirror case: fittedP very close to 0
				fittedLogP = fittedLogit;
				fittedLogQ = - exp (fittedLogit);
				fittedPQ = exp (fittedLogit);
				fittedP = exp (fittedLogP);
				fittedQ = 1 - fittedP;
			} else {   // the safe middle range: use the plain logistic formula
				fittedP = 1.0 / (1.0 + exp (- fittedLogit));
				fittedLogP = log (fittedP);
				fittedQ = 1.0 - fittedP;
				fittedLogQ = log (fittedQ);
				fittedPQ = fittedP * fittedQ;
			}
			logLikelihood += -2 * (y1 [icell] * fittedLogP + y0 [icell] * fittedLogQ);   // deviance contribution of this cell
			/*
			 * Matrix shifting stuff.
			 * Suppose a + b Sk + c Tk = ln (pk / qk),
			 * where {a, b, c} are the coefficients to be optimized,
			 * Sk and Tk are properties of stimulus k,
			 * and pk and qk are the fitted probabilities for y1 and y0, respectively, given stimulus k.
			 * Then ln pk = - ln (1 + qk / pk) = - ln (1 + exp (- (a + b Sk + c Tk)))
			 * d ln pk / da = 1 / (1 + exp (a + b Sk + c Tk)) = qk
			 * d ln pk / db = qk Sk
			 * d ln pk / dc = qk Tk
			 * d ln qk / da = - pk
			 * Now LL = Sum(k) (y1k ln pk + y0k ln qk)
			 * so that dLL/da = Sum(k) (y1k d ln pk / da + y0k ln qk / da) = Sum(k) (y1k qk - y0k pk)
			 */
			fittedVariance = fittedPQ * (y0 [icell] + y1 [icell]);
			for (integer ivar = 0; ivar <= numberOfFactors; ivar ++) {
				/*
				 * The last column gets the gradient of LL: dLL/da, dLL/db, dLL/dc.
				 */
				smallMatrix [ivar] [numberOfParameters] += x [icell] [ivar] * (y1 [icell] * fittedQ - y0 [icell] * fittedP);
				for (integer jvar = ivar; jvar <= numberOfFactors; jvar ++) {
					smallMatrix [ivar] [jvar] += x [icell] [ivar] * x [icell] [jvar] * fittedVariance;   // Hessian, upper triangle only
				}
			}
		}
		if (fabs (logLikelihood - previousLogLikelihood) < 1e-11) {
			break;   // converged
		}
		/*
		 * Make matrix symmetric.
		 */
		for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
			for (integer jvar = 0; jvar < ivar; jvar ++) {
				smallMatrix [ivar] [jvar] = smallMatrix [jvar] [ivar];
			}
		}
		/*
		 * Invert matrix in the simplest way (Gauss-Jordan, no pivoting),
		 * and shift and wipe the last column with it, so that the last column
		 * ends up holding the Newton step: Hessian^-1 times gradient.
		 */
		for (integer ivar = 0; ivar <= numberOfFactors; ivar ++) {
			double pivot = smallMatrix [ivar] [ivar]; /* Save diagonal. */
			smallMatrix [ivar] [ivar] = 1.0;
			for (integer jvar = 0; jvar <= numberOfParameters; jvar ++) {
				smallMatrix [ivar] [jvar] /= pivot;
			}
			for (integer jvar = 0; jvar <= numberOfFactors; jvar ++) {
				if (jvar != ivar) {
					double temp = smallMatrix [jvar] [ivar];
					smallMatrix [jvar] [ivar] = 0.0;
					for (integer kvar = 0; kvar <= numberOfParameters; kvar ++) {
						smallMatrix [jvar] [kvar] -= temp * smallMatrix [ivar] [kvar];
					}
				}
			}
		}
		/*
		 * Update the parameters from the last column of smallMatrix.
		 */
		thy intercept += smallMatrix [0] [numberOfParameters];
		for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
			RegressionParameter parm = thy parameters.at [ivar];
			parm -> value += smallMatrix [ivar] [numberOfParameters];
		}
	}
	if (iteration > 100) {
		Melder_warning (U"Logistic regression has not converged in 100 iterations. The results are unreliable.");
	}
	/* Undo the z-score normalization, so the coefficients apply to the original factor scales. */
	for (integer ivar = 1; ivar <= numberOfFactors; ivar ++) {
		RegressionParameter parm = thy parameters.at [ivar];
		parm -> value /= stdevX [ivar];
		thy intercept -= parm -> value * meanX [ivar];
	}
	return thee;
}
  260. autoLogisticRegression Table_to_LogisticRegression (Table me, conststring32 factors_columnLabelString,
  261. conststring32 dependent1_columnLabel, conststring32 dependent2_columnLabel)
  262. {
  263. try {
  264. auto factors_columnIndices = Table_getColumnIndicesFromColumnLabelString (me, factors_columnLabelString);
  265. integer dependent1_columnIndex = Table_getColumnIndexFromColumnLabel (me, dependent1_columnLabel);
  266. integer dependent2_columnIndex = Table_getColumnIndexFromColumnLabel (me, dependent2_columnLabel);
  267. autoLogisticRegression thee = _Table_to_LogisticRegression (me, factors_columnIndices.get(), dependent1_columnIndex, dependent2_columnIndex);
  268. return thee;
  269. } catch (MelderError) {
  270. Melder_throw (me, U": logistic regression not performed.");
  271. }
  272. }
  273. static inline double NUMmin2 (double a, double b) {
  274. return a < b ? a : b;
  275. }
  276. static inline double NUMmax2 (double a, double b) {
  277. return a > b ? a : b;
  278. }
/*
	Draw the decision boundary of the fitted model — the straight line where the
	fitted log odds are zero — in the plane of factors `colx` (horizontal) and
	`coly` (vertical). All other factors are held fixed at the midpoint of their
	observed ranges. Passing equal window edges (xleft == xright, or ybottom == ytop)
	autoscales that axis to the factor's observed range. If `garnish`, a box,
	axis marks and the two factor labels are drawn as well.
*/
void LogisticRegression_drawBoundary (LogisticRegression me, Graphics graphics, integer colx, double xleft, double xright,
	integer coly, double ybottom, double ytop, bool garnish)
{
	RegressionParameter parmx = my parameters.at [colx];
	RegressionParameter parmy = my parameters.at [coly];
	if (xleft == xright) {   // autoscale the x axis to the factor's observed range
		xleft = parmx -> minimum;
		xright = parmx -> maximum;
	}
	if (ybottom == ytop) {   // autoscale the y axis likewise
		ybottom = parmy -> minimum;
		ytop = parmy -> maximum;
	}
	/* Fold every factor other than colx and coly, at mid-range, into an effective intercept. */
	double intercept = my intercept;
	for (integer iparm = 1; iparm <= my parameters.size; iparm ++) {
		if (iparm != colx && iparm != coly) {
			RegressionParameter parm = my parameters.at [iparm];
			intercept += parm -> value * (0.5 * (parm -> minimum + parm -> maximum));
		}
	}
	Graphics_setInner (graphics);
	Graphics_setWindow (graphics, xleft, xright, ybottom, ytop);
	/*
		Where the boundary line (intercept + valuex*x + valuey*y = 0) crosses each window edge.
		NOTE(review): divides by the coefficients; a zero coefficient gives inf/NaN here, which the
		range tests below simply reject (nothing is drawn) — confirm that is the intended behavior.
	*/
	double xbottom = (intercept + parmy -> value * ybottom) / - parmx -> value;
	double xtop = (intercept + parmy -> value * ytop) / - parmx -> value;
	double yleft = (intercept + parmx -> value * xleft) / - parmy -> value;
	double yright = (intercept + parmx -> value * xright) / - parmy -> value;
	/* Normalize the window bounds, since xleft/xright (and ybottom/ytop) may be given in either order. */
	double xmin = NUMmin2 (xleft, xright), xmax = NUMmax2 (xleft, xright);
	double ymin = NUMmin2 (ybottom, ytop), ymax = NUMmax2 (ybottom, ytop);
	trace (U"LogisticRegression_drawBoundary: ",
		xmin, U" ", xmax, U" ", xbottom, U" ", xtop, U" ", ymin, U" ", ymax, U" ", yleft, U" ", yright);
	/* Connect the two window edges that the boundary line actually crosses. */
	if (xbottom >= xmin && xbottom <= xmax) {   // line goes through bottom?
		if (xtop >= xmin && xtop <= xmax)   // line goes through top?
			Graphics_line (graphics, xbottom, ybottom, xtop, ytop);   // draw from bottom to top
		else if (yleft >= ymin && yleft <= ymax)   // line goes through left?
			Graphics_line (graphics, xbottom, ybottom, xleft, yleft);   // draw from bottom to left
		else if (yright >= ymin && yright <= ymax)   // line goes through right?
			Graphics_line (graphics, xbottom, ybottom, xright, yright);   // draw from bottom to right
	} else if (yleft >= ymin && yleft <= ymax) {   // line goes through left?
		if (yright >= ymin && yright <= ymax)   // line goes through right?
			Graphics_line (graphics, xleft, yleft, xright, yright);   // draw from left to right
		else if (xtop >= xmin && xtop <= xmax)   // line goes through top?
			Graphics_line (graphics, xleft, yleft, xtop, ytop);   // draw from left to top
	} else if (xtop >= xmin && xtop <= xmax) {   // line goes through top?
		if (yright >= ymin && yright <= ymax)   // line goes through right?
			Graphics_line (graphics, xtop, ytop, xright, yright);   // draw from top to right
	}
	Graphics_unsetInner (graphics);
	if (garnish) {
		Graphics_drawInnerBox (graphics);
		Graphics_textBottom (graphics, true, parmx -> label.get());
		Graphics_marksBottom (graphics, 2, true, true, false);
		Graphics_textLeft (graphics, true, parmy -> label.get());
		Graphics_marksLeft (graphics, 2, true, true, false);
	}
}
  334. /*
  335. autoTable Table_LogisticRegression_addProbabilities (Table me, LogisticRegression thee) {
  336. for (icell = 1; icell <= numberOfCells; icell ++) {
  337. double fittedLogit = parameters [0], fittedP, fittedQ, fittedLogP, fittedLogQ;
  338. for (ivar = 1; ivar <= numberOfIndependentVariables; ivar ++) {
  339. fittedLogit += parameters [ivar] * Table_getNumericValue_Assert (me, icell, ivar);
  340. }
  341. if (fittedLogit > 15.0) {
  342. fittedLogP = - exp (- fittedLogit);
  343. fittedLogQ = - fittedLogit;
  344. fittedP = exp (fittedLogP);
  345. fittedQ = 1.0 - fittedP;
  346. } else if (fittedLogit < -15.0) {
  347. fittedLogP = fittedLogit;
  348. fittedLogQ = - exp (fittedLogit);
  349. fittedP = exp (fittedLogP);
  350. fittedQ = 1 - fittedP;
  351. } else {
  352. fittedP = 1.0 / (1.0 + exp (- fittedLogit));
  353. fittedLogP = log (fittedP);
  354. fittedQ = 1.0 - fittedP;
  355. fittedLogQ = log (fittedQ);
  356. }
  357. Table_setNumericValue (thee, icell, numberOfIndependentVariables + 1, fittedQ);
  358. Table_setNumericValue (thee, icell, numberOfIndependentVariables + 2, fittedP);
  359. }
  360. }
  361. */
  362. /* End of file LogisticRegression.cpp */