SVMUtil.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. // Copyright 2008 Rarefied Technologies, Inc.
  2. // Distributed under the GPL v2 please see
  3. // LICENSE file for more information.
#include "SVMUtil.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "parametersearch.h"
#include "parameterresult.h"
#include "../thirdparty/libsvm/svm.h"
#include "../thirdparty/boost/serialization/set.hpp"
  15. void myfunction(int);
  16. //#define SCALED_MAX 1
  17. //#define SCALED_MIN 0 //assume features > 0
  18. //#define SCALED_MIN -1 // if features can be < 0
  19. //svm_problem ParseTrainingFile(string strFilename);
  20. #define MAX_LINE_LENGTH 1024
  21. SVMUtil::SVMUtil()
  22. {
  23. m_pProblem = NULL;
  24. m_pScaleFactors = NULL;
  25. m_nParams = 0;
  26. m_pModel = NULL;
  27. }
  28. SVMUtil::~SVMUtil()
  29. {
  30. // problem destroyed when the model is
  31. //delete m_pProblem;
  32. //m_pProblem = NULL;
  33. if(m_pModel)
  34. svm_destroy_model(m_pModel);
  35. }
  36. // borrowed from read_problem in svm-train.c
  37. svm_problem* SVMUtil::ParseTrainingFile(std::string strFilename)
  38. {
  39. m_pProblem = new svm_problem;
  40. svm_node *x_space;
  41. svm_parameter param;
  42. const char* filename = strFilename.c_str();
  43. int elements, i, j;
  44. FILE *fp = fopen(filename,"r");
  45. if(fp == NULL)
  46. {
  47. fprintf(stderr,"can't open input file %s\n",filename);
  48. exit(1);
  49. }
  50. m_pProblem->l = 0;
  51. elements = 0;
  52. while(1)
  53. {
  54. int c = fgetc(fp);
  55. switch(c)
  56. {
  57. case '\n':
  58. ++m_pProblem->l;
  59. // fall through,
  60. // count the '-1' element
  61. case ':':
  62. ++elements;
  63. break;
  64. case EOF:
  65. goto out;
  66. default:
  67. ;
  68. }
  69. }
  70. out:
  71. rewind(fp);
  72. m_pProblem->y = Malloc(double,m_pProblem->l);
  73. m_pProblem->x = Malloc(struct svm_node *,m_pProblem->l);
  74. int nParamCountGuess = elements / m_pProblem->l;
  75. m_nParams = 0;
  76. for(i=0;i<m_pProblem->l;i++)
  77. {
  78. double label;
  79. x_space = Malloc(struct svm_node, nParamCountGuess+1);
  80. m_pProblem->x[i] = x_space;
  81. fscanf(fp,"%lf",&label);
  82. m_pProblem->y[i] = label;
  83. j=0;
  84. while(1)
  85. {
  86. int c;
  87. do {
  88. c = getc(fp);
  89. if(c=='\n') goto out2;
  90. } while(isspace(c));
  91. ungetc(c,fp);
  92. int nIndex;
  93. double dValue;
  94. if (fscanf(fp,"%d:%lf", &nIndex, &dValue) < 2)
  95. {
  96. fprintf(stderr,"Wrong input format at line %d\n", i+1);
  97. exit(1);
  98. }
  99. if(dValue!=0)
  100. {
  101. x_space[j].index = nIndex;
  102. x_space[j].value = dValue;
  103. ++j;
  104. }
  105. }
  106. out2:
  107. if(j>=1 && x_space[j-1].index > m_nParams)
  108. m_nParams = x_space[j-1].index;
  109. x_space[j++].index = -1;
  110. }
  111. if(param.gamma == 0)
  112. param.gamma = 1.0/m_nParams;
  113. if(param.kernel_type == PRECOMPUTED)
  114. for(i=0;i<m_pProblem->l;i++)
  115. {
  116. if (m_pProblem->x[i][0].index != 0)
  117. {
  118. fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
  119. exit(1);
  120. }
  121. if ((int)m_pProblem->x[i][0].value <= 0 || (int)m_pProblem->x[i][0].value > m_nParams)
  122. {
  123. fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
  124. exit(1);
  125. }
  126. }
  127. fclose(fp);
  128. ScaleTrainingData();
  129. SaveScaleFactors(strFilename + ".sf");
  130. return m_pProblem;
  131. }
  132. bool SVMUtil::ScaleTrainingData()
  133. {
  134. if(!m_pProblem)
  135. {
  136. assert(0);
  137. return false;
  138. }
  139. if(!DetermineScaleFactors())
  140. return false;
  141. svm_node* pNode = NULL;
  142. for(int i=0; i < m_pProblem->l; i++)
  143. {
  144. pNode = m_pProblem->x[i];
  145. ScaleNode(pNode);
  146. }
  147. return true;
  148. }
  149. bool SVMUtil::DetermineScaleFactors()
  150. {
  151. if(!m_pProblem)
  152. return false;
  153. svm_node* pNode = NULL;
  154. double* pMaxValue = Malloc(double, m_nParams);
  155. m_pScaleFactors = Malloc(double, m_nParams);
  156. for(int j=0; j < m_nParams; j++)
  157. {
  158. pMaxValue[j] = 0; // assumes values should be scaled between 0 and 1
  159. }
  160. for(int i=0; i < m_pProblem->l; i++)
  161. {
  162. pNode = m_pProblem->x[i];
  163. int nindex = 0;
  164. int j=0;
  165. while(pNode)
  166. {
  167. nindex = pNode[j].index;
  168. if(nindex==-1)
  169. break;
  170. pMaxValue[nindex-1] = max(pMaxValue[nindex-1], pNode[j].value); // assume values are positive
  171. j++;
  172. }
  173. }
  174. for(int j=0; j < m_nParams; j++)
  175. {
  176. if(pMaxValue[j] > 0)
  177. m_pScaleFactors[j] = (double)1./pMaxValue[j];
  178. else
  179. m_pScaleFactors[j] = 1;
  180. }
  181. return true;
  182. }
  183. bool SVMUtil::ScaleNode(svm_node* pNode)
  184. {
  185. if(!pNode)
  186. {
  187. cerr << "error scaling" << endl;
  188. assert(0);
  189. return false;
  190. }
  191. if(!m_pScaleFactors)
  192. {
  193. if(m_pProblem)
  194. DetermineScaleFactors();
  195. else
  196. {
  197. assert(0);
  198. return false;
  199. }
  200. }
  201. int i = 0;
  202. while(pNode[i].index != -1)
  203. {
  204. pNode[i].value *= m_pScaleFactors[pNode[i].index-1];
  205. i++;
  206. }
  207. return true;
  208. }
  209. bool SVMUtil::ParameterSearch(svm_parameter* pSvmParam, string strFilename)
  210. {
  211. if(!m_pProblem || !pSvmParam)
  212. return false;
  213. //struct sigaction sa;
  214. //sa.sa_handler = &myfunction;
  215. //sigaction(SIGINT, &sa, NULL);
  216. CParameterSearch* paramSearch = new CParameterSearch(m_pProblem, pSvmParam, strFilename);
  217. //SaveSearch(paramSearch);
  218. delete paramSearch;
  219. return true;
  220. }
  221. bool SVMUtil::SaveSearch(const CParameterSearch* p_Search)
  222. {
  223. std::ofstream ofs("searchResults.txt");
  224. boost::archive::text_oarchive oa(ofs);
  225. oa << *p_Search;
  226. return true;
  227. }
  228. void myfunction(int number)
  229. {
  230. int ten;
  231. int five = number;
  232. ten = five + five;
  233. }
  234. bool SVMUtil::Load(string filename)
  235. {
  236. if(m_pModel)
  237. svm_destroy_model(m_pModel);
  238. string modelname = filename + ".mod";
  239. string scalename = filename + ".sf";
  240. m_pModel = svm_load_model(modelname.c_str());
  241. LoadScaleFactors(scalename);
  242. return true;
  243. }
  244. bool SVMUtil::Save(string filename)
  245. {
  246. string modelname = filename + ".mod";
  247. string scalename = filename + ".sf";
  248. svm_save_model(modelname.c_str(), m_pModel);
  249. SaveScaleFactors(scalename);
  250. return true;
  251. }
  252. void SVMUtil::SaveScaleFactors(string filename)
  253. {
  254. ofstream fout;
  255. fout.open(filename.c_str());
  256. fout << m_nParams << endl;
  257. for(int i=0; i<m_nParams; i++)
  258. {
  259. fout << m_pScaleFactors[i] << endl;
  260. }
  261. fout.close();
  262. }
  263. bool SVMUtil::LoadScaleFactors(string filename)
  264. {
  265. ifstream fin;
  266. fin.open(filename.c_str());
  267. fin >> m_nParams;
  268. if(m_pScaleFactors)
  269. delete m_pScaleFactors;
  270. m_pScaleFactors = new double[m_nParams];
  271. for(int i=0; i<m_nParams; i++)
  272. {
  273. fin >> m_pScaleFactors[i];
  274. }
  275. return true;
  276. }
  277. bool SVMUtil::CrossValidate(int nFolds, svm_parameter* pParam)
  278. {
  279. double* target = new double[m_pProblem->l];
  280. if(nFolds>0)
  281. svm_cross_validation(m_pProblem, pParam, nFolds, target);
  282. else
  283. {
  284. if(!m_pModel)
  285. m_pModel = svm_train(m_pProblem, pParam);
  286. for(int i=0; i<m_pProblem->l; i++)
  287. {
  288. target[i] = svm_predict(m_pModel, m_pProblem->x[i]);
  289. }
  290. }
  291. float fError = 0;
  292. float fWrong = 0;
  293. for(int i=0; i<m_pProblem->l; i++)
  294. {
  295. fError += abs(m_pProblem->y[i] - target[i]) ;
  296. if( m_pProblem->y[i] >= 0.5 && target[i] < 0.5)
  297. fWrong++;
  298. else if( m_pProblem->y[i] < 0.5 && target[i] >= 0.5)
  299. fWrong++;
  300. }
  301. fError = (float)fError/m_pProblem->l;
  302. fWrong = (float) fWrong/m_pProblem->l;
  303. float fStdDev = 0;
  304. for(int i=0; i<m_pProblem->l; i++)
  305. {
  306. fStdDev += pow(fError - abs(m_pProblem->y[i] - target[i]), 2) ;
  307. }
  308. fStdDev = pow(fStdDev, (float)0.5) / m_pProblem->l;
  309. std::cout << "Percent wrong: " << fWrong << " Avg Error: " << fError << " Std Dev: " << fStdDev << std::endl;
  310. return true;
  311. }