# ferns.R -- R interface of the rFerns package
# R part of rFerns
#
# Copyright 2011-2018 Miron B. Kursa
#
# This file is part of rFerns R package.
#
# rFerns is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
# rFerns is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with rFerns. If not, see http://www.gnu.org/licenses/.
  10. #' @rdname rFerns
  11. #' @export
  12. rFerns<-function(x,...)
  13. UseMethod("rFerns")
  14. #' @rdname rFerns
  15. #' @method rFerns formula
  16. #' @param formula alternatively, formula describing model to be analysed.
  17. #' @param data in which to interpret formula.
  18. #' @export
  19. rFerns.formula<-function(formula,data=.GlobalEnv,...){
  20. #Convert formula into a data frame
  21. stats::terms.formula(formula,data=data)->t
  22. rx<-eval(attr(t,"variables"),data)
  23. apply(attr(t,"factors"),1,sum)>0->sel
  24. nam<-rownames(attr(t,"factors"))[sel]
  25. data.frame(rx[sel])->x; names(x)<-nam
  26. rx[[attr(t,"response")]]->y
  27. #Pass to the default method
  28. rFerns.default(x,y,...)
  29. }
  30. #' @rdname rFerns
  31. #' @method rFerns matrix
  32. #' @export
  33. rFerns.matrix<-function(x,y,...){
  34. #If the input is matrix, data.frame it first
  35. rFerns.default(data.frame(x),y,...)
  36. }
  37. #' Classification with random ferns
  38. #'
  39. #' This function builds a random ferns model on the given training data.
  40. #' @rdname rFerns
  41. #' @method rFerns default
  42. #' @param x Data frame containing attributes; must have unique names and contain only numeric, integer or (ordered) factor columns.
  43. #' Factors must have less than 31 levels. No \code{NA} values are permitted.
  44. #' @param y A decision vector. Must a factor of the same length as \code{nrow(X)} for ordinary many-label classification, or a logical matrix with each column corresponding to a class for multi-label classification.
  45. #' @param depth The depth of the ferns; must be in 1--16 range. Note that time and memory requirements scale with \code{2^depth}.
  46. #' @param ferns Number of ferns to be build.
  47. #' @param importance Set to calculate attribute importance measure (VIM);
  48. #' \code{"simple"} will calculate the default mean decrease of true class score (MDTS, something similar to Random Forest's MDA/MeanDecreaseAccuracy),
  49. #' \code{"shadow"} will calculate MDTS and additionally MDTS of this attribute shadow, an implicit feature build by shuffling values within it, thus stripping it from information (which is slightly slower).
  50. #' Shadow importance is useful as a reference to judge significance of a regular importance.
  51. #' \code{"none"} turns importance calculation off, for a slightly faster execution.
  52. #' For compatibility with pre-1.2 rFerns, \code{TRUE} will resolve to \code{"simple"} and \code{FALSE} to \code{"none"}.
  53. #' Abbreviation can be used instead of a full value.
  54. #' @param saveForest Should the model be saved? It must be \code{TRUE} if you want to use the model for prediction; however, if you are interested in importance or OOB error only, setting it to \code{FALSE} significantly improves memory requirements, especially for large \code{depth} and \code{ferns}.
  55. #' @param consistentSeed PRNG seed used for shadow importance \emph{only}.
  56. #' Must be either a 2-element integer vector or \code{NULL}, which corresponds to seeding from the default PRNG.
  57. #' @param threads Number or OpenMP threads to use. The default value of \code{0} means all available to OpenMP.
  58. #' It should be set to the same value in two merged models to make shadow importance meaningful.
  59. #' @param ... For formula and matrix methods, a place to state parameters to be passed to default method.
  60. #' For the print method, arguments to be passed to \code{print}.
  61. #' @return An object of class \code{rFerns}, which is a list with the following components:
  62. #' \item{model}{The built model; \code{NULL} if \code{saveForest} was \code{FALSE}.}
  63. #' \item{oobErr}{OOB approximation of accuracy.
  64. #' Ignores never-OOB-tested objects (see oobScores element).}
  65. #' \item{importance}{The importance scores or \code{NULL} if \code{importance} was set to \code{"none"}.
  66. #' In a first case it is a \code{data.frame} with two or three columns:
  67. #' \code{MeanScoreLoss} which is a mean decrease of a score of a correct class when a certain attribute is permuted,
  68. #' \code{Tries} which is number of ferns which utilised certain attribute, and, only when \code{importance} was set to \code{"shadow"},
  69. #' \code{Shadow}, which is a mean decrease of accuracy for the correct class for a permuted copy of an attribute (useful as a baseline for normal importance).
  70. #' The \code{rownames} are set and equal to the \code{names(x)}.}
  71. #' \item{oobScores}{A matrix of OOB scores of each class for each object in training set.
  72. #' Rows correspond to classes in the same order as in \code{levels(Y)}.
  73. #' If the \code{ferns} is too small, some columns may contain \code{NA}s, what means that certain objects were never in test set.}
  74. #' \item{oobPreds}{A vector of OOB predictions of class for each object in training set. Never-OOB-tested objects (see above) have predictions equal to \code{NA}.}
  75. #' \item{oobConfusionMatrix}{Confusion matrix build from \code{oobPreds} and \code{y}.}
  76. #' \item{timeTaken}{Time used to train the model (smaller than wall time because data preparation and model final touches are excluded; however it includes the time needed to compute importance, if it applies).
  77. #' An object of \code{difftime} class.}
  78. #' \item{parameters}{Numerical vector of three elements: \code{classes}, \code{depth} and \code{ferns}, containing respectively the number of classes in decision and copies of \code{depth} and \code{ferns} parameters.}
  79. #' \item{classLabels}{Copy of \code{levels(Y)} after purging unused levels.}
  80. #' \item{consistentSeed}{Consistent seed used; only present for \code{importance="shadow"}.
  81. #' Can be used to seed a new model via \code{consistentSeed} argument.}
  82. #' \item{isStruct}{Copy of the train set structure, required internally by predict method.}
  83. #' @note The unused levels of the decision will be removed; on the other hand unused levels of categorical attributes will be preserved, so that they could be present in the data later predicted with the model.
  84. #' The levels of ordered factors in training and predicted data must be identical.
  85. #'
  86. #' Do not use formula interface for a data with large number of attributes; the overhead from handling the formula may be significant.
  87. #' @references Ozuysal M, Calonder M, Lepetit V & Fua P. (2009). \emph{Fast Keypoint Recognition using Random Ferns}, IEEE Transactions on Pattern Analysis and Machine Intelligence, 32(3), 448-461.
  88. #'
  89. #' Kursa MB (2014). \emph{rFerns: An Implementation of the Random Ferns Method for General-Purpose Machine Learning}, Journal of Statistical Software, 61(10), 1-13.
  90. #' @examples
  91. #' set.seed(77)
  92. #' #Fetch Iris data
  93. #' data(iris)
  94. #' #Build model
  95. #' rFerns(Species~.,data=iris)
  96. #' ##Importance
  97. #' rFerns(Species~.,data=iris,importance="shadow")->model
  98. #' print(model$imp)
  99. #' @export
  100. #' @useDynLib rFerns, .registration=TRUE
  101. rFerns.default<-function(x,y,depth=5,ferns=1000,importance="none",saveForest=TRUE,consistentSeed=NULL,threads=0,...){
  102. #Stop on bad input
  103. depth<-as.integer(depth)
  104. ferns<-as.integer(ferns)
  105. stopifnot(length(depth)==1 && depth>0 && depth<=16)
  106. stopifnot(length(ferns)==1 && ferns>0)
  107. stopifnot(!any(is.na(y)))
  108. if(!is.data.frame(x)) stop("x must be a data frame.")
  109. if(is.na(names(x)) || any(duplicated(names(x)))) stop("Attribute names must be unique.")
  110. if(is.factor(y) && is.null(dim(y))){
  111. multi<-FALSE
  112. if(length(y)!=nrow(x)) stop("Attributes' and decision's sizes must match.")
  113. }else{
  114. y<-as.matrix(y)
  115. if(is.logical(y) && length(dim(y))==2){
  116. multi<-TRUE
  117. if(nrow(y)!=nrow(x)) stop("Attributes' and decision's sizes must match.")
  118. }else{
  119. stop("y must be a factor vector or a logical matrix.")
  120. }
  121. }
  122. if(!all(sapply(x,function(j) any(class(j)%in%c("numeric","integer","factor","ordered")))))
  123. stop("All attributes must be either numeric or factor.")
  124. if(any((sapply(x,function(a) ((length(levels(a))>30)&&(!is.ordered(a)))))->bad)){
  125. stop(sprintf(
  126. "Attribute(s) %s is/are unordered factor(s) with above 30 levels. Split or convert to ordered.",
  127. paste(names(x)[bad],collapse=", ")))
  128. }
  129. #Backward compatibility
  130. if(length(importance)!=1) stop("Wrong importance value.")
  131. if(identical(importance,FALSE)) importance<-"none"
  132. if(identical(importance,TRUE)) importance<-"simple"
  133. importance<-pmatch(importance,c("none","simple","shadow"))
  134. if(is.na(importance)) stop("Wrong importance value.")
  135. #Consistent seed setup
  136. if(importance==3){
  137. if(is.null(consistentSeed)){
  138. consistentSeed<-as.integer(sample(2^32-1,2,replace=TRUE)-2^31)
  139. }
  140. stopifnot(is.integer(consistentSeed))
  141. stopifnot(length(consistentSeed)==2)
  142. }else{
  143. if(!is.null(consistentSeed)){
  144. warning("Consistent seed is only useful with shadow importance; dropping.")
  145. consistentSeed<-NULL
  146. }
  147. }
  148. if(multi && (importance>1)) stop("Importance is not yet supported for multi-label ferns.")
  149. Sys.time()->before
  150. .Call(random_ferns,x,y,
  151. as.integer(depth[1]),
  152. as.integer(ferns[1]),
  153. as.integer(importance-1), #0->none, 1->msl, 2->msl+sha
  154. as.integer(saveForest),
  155. as.integer(multi),
  156. as.integer(consistentSeed),
  157. as.integer(threads))->ans
  158. after<-Sys.time()
  159. #Adjust C output with proper factor levels
  160. if(!is.null(ans$oobPreds)){
  161. ans$oobPreds<-factor(ans$oobPreds,
  162. levels=0:(length(levels(y))-1),
  163. labels=levels(y))
  164. }else{
  165. if(multi){
  166. ans$oobPreds<-t(ans$oobScores>0)
  167. colnames(ans$oobPreds)<-colnames(y)
  168. }
  169. }
  170. if(!multi){
  171. ans$classLabels<-levels(y)
  172. }else{
  173. ans$classLabels<-colnames(y)
  174. }
  175. if(saveForest){
  176. ans$isStruct<-list()
  177. lapply(x,levels)->ans$isStruct$predictorLevels
  178. sapply(x,is.integer)->ans$isStruct$integerPredictors
  179. sapply(x,is.ordered)->ans$isStruct$orderedFactorPredictors
  180. }
  181. if(!multi){
  182. table(Predicted=ans$oobPreds,True=y)->ans$oobConfusionMatrix
  183. if(is.null(ans$oobErr))
  184. ans$oobErr<-mean(ans$oobPreds!=y,na.rm=TRUE)
  185. ans$parameters<-c(classes=length(levels(y)),depth=depth,ferns=ferns)
  186. ans$type<-"class-many"
  187. }else{
  188. NULL->ans$oobConfusionMatrix
  189. if(is.null(ans$oobErr))
  190. ans$oobErr<-mean(rowSums(y!=ans$oobPreds))
  191. ans$oobPerClassError<-colMeans(ans$oobPreds!=y)
  192. ans$parameters<-c(classes=ncol(y),depth=depth,ferns=ferns)
  193. ans$type<-"class-multi"
  194. }
  195. if(!is.null(ans$importance)){
  196. if(importance==2){
  197. ans$importance<-data.frame(matrix(ans$importance,ncol=2))
  198. names(ans$importance)<-c("MeanScoreLoss","Tries")
  199. }
  200. if(importance==3){
  201. ans$importance<-data.frame(matrix(ans$importance,ncol=3))
  202. names(ans$importance)<-c("MeanScoreLoss","Shadow","Tries")
  203. ans$consistentSeed<-consistentSeed
  204. }
  205. if(!is.null(names(x)))
  206. rownames(ans$importance)<-names(x)
  207. }
  208. #Calculate time taken by the calculation
  209. ans$timeTaken<-after-before
  210. class(ans)<-"rFerns"
  211. return(ans)
  212. }
  213. #' Prediction with random ferns model
  214. #'
  215. #' This function predicts classes of new objects with given \code{rFerns} object.
  216. #' @method predict rFerns
  217. #' @param object Object of a class \code{rFerns}; a model that will be used for prediction.
  218. #' @param x Data frame containing attributes; must have corresponding names to training set (although order is not important) and do not introduce new factor levels.
  219. #' If this argument is not given, OOB predictions on the training set will be returned.
  220. #' @param scores If \code{TRUE}, the result will contain score matrix instead of simple predictions.
  221. #' @param ... Additional parameters.
  222. #' @return Predictions.
  223. #' If \code{scores} is \code{TRUE}, a factor vector (for many-class classification) or a logical data.frame (for multi-class classification) with predictions, else a data.frame with class' scores.
  224. #' @examples
  225. #' set.seed(77)
  226. #' #Fetch Iris data
  227. #' data(iris)
  228. #' #Split into tRain and tEst set
  229. #' iris[c(TRUE,FALSE),]->irisR
  230. #' iris[c(FALSE,TRUE),]->irisE
  231. #' #Build model
  232. #' rFerns(Species~.,data=irisR)->model
  233. #' print(model)
  234. #'
  235. #' #Test
  236. #' predict(model,irisE)->p
  237. #' print(table(
  238. #' Predictions=p,
  239. #' True=irisE[["Species"]]))
  240. #' err<-mean(p!=irisE[["Species"]])
  241. #' print(paste("Test error",err,sep=" "))
  242. #'
  243. #' #Show first OOB scores
  244. #' head(predict(model,scores=TRUE))
  245. #' @export
  246. predict.rFerns<-function(object,x,scores=FALSE,...){
  247. #Validate input
  248. if(!("rFerns"%in%class(object))) stop("object must be of a rFerns class")
  249. if(is.null(object$model)&(!missing(x)))
  250. stop("This fern forest object does not contain the model.")
  251. scores<-as.logical(scores)[1]
  252. if(is.na(scores)) stop("Wrong value of scores; should be TRUE or FALSE.")
  253. iss<-object$isStruct
  254. if(is.null(iss)){
  255. #object is a v0.1 rFerns
  256. object$isStruct$predictorLevels<-object$predictorLevels
  257. object$isStruct$integerPredictors<-
  258. object$isStruct$orderedFactorPredictors<-
  259. rep(FALSE,length(iss$predictorLevels))
  260. }
  261. iss$predictorLevels->pL
  262. pN<-names(pL)
  263. multi<-identical(object$type,"class-multi")
  264. if(missing(x)){
  265. #There is no x; return the OOB predictions, nicely formatted
  266. if(scores){
  267. data.frame(t(object$oobScores))->ans
  268. object$classLabels->names(ans)
  269. return(ans)
  270. }else{
  271. if(multi){
  272. return(data.frame(object$oobPreds))
  273. }else{
  274. return(object$oobPreds)
  275. }
  276. }
  277. }
  278. if(!identical(names(x),pN)){
  279. #Restore x state from training based on x's names
  280. if(!all(pN%in%names(x))){
  281. stop("Some training attributes missing in test.")
  282. }
  283. x[,pN]->x
  284. }
  285. #Fail for NAs in input
  286. if(any(is.na(x))) stop("NAs in predictors.")
  287. for(e in 1:ncol(x))
  288. if(is.null(pL[[e]])){
  289. if(iss$integerPredictors[e]){
  290. if(!("integer"%in%class(x[,e]))) stop(sprintf("Attribute %s should be integer.",pN[e]))
  291. }else{
  292. if(!("numeric"%in%class(x[,e]))) stop(sprintf("Attribute %s should be numeric.",pN[e]))
  293. }
  294. }else{
  295. if(iss$orderedFactorPredictors[e]){
  296. #Check if given attribute is also ordered
  297. if(!is.ordered(x[,e])) stop(sprintf("Attribute %s should be an ordered factor.",pN[e]))
  298. #Convert levels
  299. if(!identical(levels(x[,e]),pL[[e]]))
  300. stop(sprintf("Levels of %s does not match those from training (%s).",pN[e],paste(pL[[e]],collapse=", ")))
  301. }else{
  302. #Convert factor levels to be compatible with training
  303. if(!identical(levels(x[,e]),pL[[e]]))
  304. x[,e]<-factor(x[,e],levels=pL[[e]])
  305. #In case of mismatch, NAs will appear -- catch 'em and fail
  306. if(any(is.na(x[,e]))) stop(sprintf("Levels of %s does not match those from training (%s).",pN[e],paste(pL[[e]],collapse=", ")))
  307. }
  308. }
  309. #Prediction itself
  310. Sys.time()->before
  311. .Call(random_ferns_predict,x,
  312. object$model,
  313. as.integer(object$parameters["depth"]),
  314. as.integer(object$parameters["ferns"]),
  315. as.integer(length(object$classLabels)),
  316. as.integer(scores),as.integer(multi))->ans
  317. after<-Sys.time()
  318. if(scores){
  319. ans<-data.frame(matrix(ans,ncol=length(object$classLabels),byrow=TRUE)/
  320. object$parameters["ferns"])
  321. object$classLabels->names(ans)
  322. }else{
  323. if(!multi){
  324. ans<-factor(ans,levels=0:(length(object$classLabels)-1),
  325. labels=object$classLabels)
  326. }else{
  327. ans<-data.frame(matrix(ans,ncol=length(object$classLabels),byrow=TRUE)>0)
  328. object$classLabels->names(ans)
  329. }
  330. }
  331. #Store the timing
  332. attr(ans,"timeTaken")<-after-before
  333. return(ans)
  334. }
  335. #' @method print rFerns
  336. #' @export
  337. print.rFerns<-function(x,...){
  338. #Pretty-print rFerns output
  339. cat(sprintf("\n Forest of %s %sferns of a depth %s.\n\n",
  340. x$parameters["ferns"],
  341. ifelse(identical(x$type,"class-multi"),"multi-class ",""),
  342. x$parameters["depth"]))
  343. if(identical(x$type,"class-multi")){
  344. if(!is.null(x$oobErr))
  345. cat(sprintf(" OOB Hamming distance %0.3f for %d classes.\n",
  346. utils::tail(x$oobErr,1),
  347. x$parameters["classes"]))
  348. if(!is.null(x$oobPerClassError)){
  349. cat(" Per-class error rates:\n")
  350. print(x$oobPerClassError)
  351. }
  352. }else{
  353. if(!is.null(x$oobErr))
  354. cat(sprintf(" OOB error %0.2f%%;",utils::tail(x$oobErr,1)*100))
  355. if(!is.null(x$oobConfusionMatrix)){
  356. cat(" OOB confusion matrix:\n")
  357. print(x$oobConfusionMatrix)
  358. }
  359. }
  360. if(!is.null(x$oobScores) && any(is.na(x$oobScores)))
  361. cat(" Note: forest too small to provide good OOB approx.\n")
  362. return(invisible(x))
  363. }