From adf98c8421ced5563acee0a27676a0e6b77c963a Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 17:02:13 -0500 Subject: [PATCH 01/12] created predict_prob function for forest --- R/ccf.R | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/R/ccf.R b/R/ccf.R index 69b8ca0..832f5ae 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -154,6 +154,37 @@ predict.canonical_correlation_forest = function( return(treePredictions) } +#' @export +predict_prob.canonical_correlation_forest = function( + object, newdata, verbose = FALSE, ...) { + if (missing(newdata)) { + stop("Argument 'newdata' is missing.") + } + + ntree <- length(object$forest) + treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) + + if (length(unique(object$y)) > 2 { + stop("predict_prob currently only implemented for binary classifier. More than two classes detected") + } + + if (verbose) { + cat("calculating predictions\n") + } + # returns list of list + treePredictions = lapply(object$forest, predict, newdata) + # convert to matrix + treePredictions = do.call(cbind, treePredictions) + if (verbose) { + cat("Majority vote\n") + } + probs <- apply(treePredictions, 1, function(row) { + sum(table(row) == "1")/ntree #hardcoded to count predictions with class name "1" + }) + + return(probs) +} + #' Visualization of canonical correlation forest #' #' TODO: document From fe67251c331f8ef5c2fd0462324671b3a87c5a5a Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 17:03:28 -0500 Subject: [PATCH 02/12] fixed bracket --- R/ccf.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ccf.R b/R/ccf.R index 832f5ae..0028b29 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -164,7 +164,7 @@ predict_prob.canonical_correlation_forest = function( ntree <- length(object$forest) treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) - if (length(unique(object$y)) > 2 { + if (length(unique(object$y)) > 2 ){ stop("predict_prob currently only implemented 
for binary classifier. More than two classes detected") } From 90de956dfe74a9aadddc01bdd6960698a27a6f29 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 17:08:59 -0500 Subject: [PATCH 03/12] debugging --- R/ccf.R | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/R/ccf.R b/R/ccf.R index 0028b29..aff6c99 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -154,6 +154,36 @@ predict.canonical_correlation_forest = function( return(treePredictions) } +#' @export +predict_prob2.canonical_correlation_forest = function( + object, newdata, verbose = FALSE, ...) { + if (missing(newdata)) { + stop("Argument 'newdata' is missing.") + } + + ntree <- length(object$forest) + treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) + + if (length(unique(object$y)) > 2 ){ + stop("predict_prob currently only implemented for binary classifier. More than two classes detected") + } + + if (verbose) { + cat("calculating predictions\n") + } + # returns list of list + treePredictions = lapply(object$forest, predict, newdata) + # convert to matrix + treePredictions = do.call(cbind, treePredictions) + if (verbose) { + cat("Majority vote\n") + } + probs <- apply(treePredictions, 1, function(row) { + sum(table(row) == "1")/ntree #hardcoded to count predictions with class name "1" + }) + + return(probs) +} #' @export predict_prob.canonical_correlation_forest = function( object, newdata, verbose = FALSE, ...) 
{ From 7eacb4c9c15018a5a065aadd1e01e345803466fc Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 17:20:06 -0500 Subject: [PATCH 04/12] modified method export --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/ccf.R | 31 +-------------------- man/canonical_correlation_tree.Rd | 13 +++++++-- man/ccf.Rd | 16 +++++++---- man/predict.canonical_correlation_forest.Rd | 3 +- man/spirals.Rd | 6 ++-- 7 files changed, 28 insertions(+), 44 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7304e27..867f3d5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,5 +26,5 @@ Suggests: MASS, pracma LazyData: TRUE -RoxygenNote: 6.0.1 +RoxygenNote: 7.1.1 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index 02248c3..f990745 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,7 @@ S3method(plot,canonical_correlation_forest) S3method(plot,canonical_correlation_tree) S3method(predict,canonical_correlation_forest) S3method(predict,canonical_correlation_tree) +S3method(predict_prob,canonical_correlation_forest) export(canonical_correlation_analysis) export(canonical_correlation_forest) export(canonical_correlation_tree) diff --git a/R/ccf.R b/R/ccf.R index aff6c99..eb5abb4 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -154,36 +154,7 @@ predict.canonical_correlation_forest = function( return(treePredictions) } -#' @export -predict_prob2.canonical_correlation_forest = function( - object, newdata, verbose = FALSE, ...) { - if (missing(newdata)) { - stop("Argument 'newdata' is missing.") - } - - ntree <- length(object$forest) - treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) - - if (length(unique(object$y)) > 2 ){ - stop("predict_prob currently only implemented for binary classifier. 
More than two classes detected") - } - - if (verbose) { - cat("calculating predictions\n") - } - # returns list of list - treePredictions = lapply(object$forest, predict, newdata) - # convert to matrix - treePredictions = do.call(cbind, treePredictions) - if (verbose) { - cat("Majority vote\n") - } - probs <- apply(treePredictions, 1, function(row) { - sum(table(row) == "1")/ntree #hardcoded to count predictions with class name "1" - }) - - return(probs) -} +#' @method predict_prob canonical_correlation_forest #' @export predict_prob.canonical_correlation_forest = function( object, newdata, verbose = FALSE, ...) { diff --git a/man/canonical_correlation_tree.Rd b/man/canonical_correlation_tree.Rd index d30eb67..4a93da8 100644 --- a/man/canonical_correlation_tree.Rd +++ b/man/canonical_correlation_tree.Rd @@ -4,9 +4,16 @@ \alias{canonical_correlation_tree} \title{Computes a canonical correlation tree} \usage{ -canonical_correlation_tree(X, Y, depth = 0, minPointsForSplit = 2, - maxDepthSplit = Inf, xVariationTolerance = 1e-10, - projectionBootstrap = FALSE, ancestralProbs = NULL) +canonical_correlation_tree( + X, + Y, + depth = 0, + minPointsForSplit = 2, + maxDepthSplit = Inf, + xVariationTolerance = 1e-10, + projectionBootstrap = FALSE, + ancestralProbs = NULL +) } \arguments{ \item{X}{Predictor matrix of size \eqn{n \times p} with \eqn{n} observations and \eqn{p} diff --git a/man/ccf.Rd b/man/ccf.Rd index fb0ec10..23cbdab 100644 --- a/man/ccf.Rd +++ b/man/ccf.Rd @@ -6,14 +6,18 @@ \alias{canonical_correlation_forest.formula} \title{Canonical correlation forest} \usage{ -canonical_correlation_forest(x, y = NULL, ntree = 200, verbose = FALSE, - ...) +canonical_correlation_forest(x, y = NULL, ntree = 200, verbose = FALSE, ...) -\method{canonical_correlation_forest}{default}(x, y = NULL, ntree = 200, - verbose = FALSE, projectionBootstrap = FALSE, ...) 
+\method{canonical_correlation_forest}{default}( + x, + y = NULL, + ntree = 200, + verbose = FALSE, + projectionBootstrap = FALSE, + ... +) -\method{canonical_correlation_forest}{formula}(x, y = NULL, ntree = 200, - verbose = FALSE, ...) +\method{canonical_correlation_forest}{formula}(x, y = NULL, ntree = 200, verbose = FALSE, ...) } \arguments{ \item{x}{Numeric matrix (n * p) with n observations of p variables} diff --git a/man/predict.canonical_correlation_forest.Rd b/man/predict.canonical_correlation_forest.Rd index 1ac83b5..853aa69 100644 --- a/man/predict.canonical_correlation_forest.Rd +++ b/man/predict.canonical_correlation_forest.Rd @@ -4,8 +4,7 @@ \alias{predict.canonical_correlation_forest} \title{Prediction from canonical correlation forest} \usage{ -\method{predict}{canonical_correlation_forest}(object, newdata, - verbose = FALSE, ...) +\method{predict}{canonical_correlation_forest}(object, newdata, verbose = FALSE, ...) } \arguments{ \item{object}{An object of class \code{canonical_correlation_forest}, as created diff --git a/man/spirals.Rd b/man/spirals.Rd index dd5cb64..95cbc8a 100644 --- a/man/spirals.Rd +++ b/man/spirals.Rd @@ -4,12 +4,14 @@ \name{spirals} \alias{spirals} \title{Spiral dataset} -\format{A data frame with 10000 rows and 3 variables: +\format{ +A data frame with 10000 rows and 3 variables: \describe{ \item{x}{numeric scalar: x-coordinate} \item{y}{numeric scalar: y-coordinate} \item{class}{integer: either 1,2 or 3} -}} +} +} \source{ Created by T. 
Rainforth, URL: \url{https://bitbucket.org/twgr/ccf/raw/49d5fce6fc006bc9a8949c7149fc9524535ce418/Datasets/spirals.csv} From cfe188e91551c176e9d11ab4c831ec82fcdc9779 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 18:52:09 -0500 Subject: [PATCH 05/12] added probability functionality (untested) --- DESCRIPTION | 2 +- NAMESPACE | 2 +- R/ccf.R | 50 ++++++++++++++------- R/cct.R | 10 ++++- man/predict.canonical_correlation_forest.Rd | 2 +- 5 files changed, 44 insertions(+), 22 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 867f3d5..b3f7f8d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,5 +26,5 @@ Suggests: MASS, pracma LazyData: TRUE -RoxygenNote: 7.1.1 +RoxygenNote: 7.1.2 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index f990745..d45fa6d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,12 +6,12 @@ S3method(plot,canonical_correlation_forest) S3method(plot,canonical_correlation_tree) S3method(predict,canonical_correlation_forest) S3method(predict,canonical_correlation_tree) -S3method(predict_prob,canonical_correlation_forest) export(canonical_correlation_analysis) export(canonical_correlation_forest) export(canonical_correlation_tree) export(get_missclassification_rate) export(plot_decision_surface) +export(predict_proba.canonical_correlation_forest) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,model.response) diff --git a/R/ccf.R b/R/ccf.R index eb5abb4..6057e58 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -128,11 +128,15 @@ canonical_correlation_forest.formula = function( #' canonical correlation trees. #' @export predict.canonical_correlation_forest = function( - object, newdata, verbose = FALSE, ...) { + object, newdata, verbose = FALSE, prob = FALSE, ...) { if (missing(newdata)) { stop("Argument 'newdata' is missing.") } + if (prob == TRUE && length(unique(object$y)) > 2 ){ + stop("predict with prob == TRUE currently only implemented for binary classifier. 
More than two classes detected") + } + ntree <- length(object$forest) treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) @@ -140,23 +144,38 @@ predict.canonical_correlation_forest = function( if (verbose) { cat("calculating predictions\n") } - # returns list of list - treePredictions = lapply(object$forest, predict, newdata) - # convert to matrix - treePredictions = do.call(cbind, treePredictions) - if (verbose) { - cat("Majority vote\n") - } - treePredictions <- apply(treePredictions, 1, function(row) { - names(which.max(table(row))) - }) - return(treePredictions) + if(prob){ + # returns list of list + treePredictions = lapply(object$forest, predict, newdata, prob = TRUE) + # convert to matrix + treePredictions = do.call(cbind, treePredictions) + + # todo: use built in function + probs <- apply(treePredictions, 1, function(row) { + mean(table(row)) + }) + + return(probs) + + }else{ + # returns list of list + treePredictions = lapply(object$forest, predict, newdata) + # convert to matrix + treePredictions = do.call(cbind, treePredictions) + if (verbose) { + cat("Majority vote\n") + } + treePredictions <- apply(treePredictions, 1, function(row) { + names(which.max(table(row))) + }) + + return(treePredictions) + } } -#' @method predict_prob canonical_correlation_forest #' @export -predict_prob.canonical_correlation_forest = function( +predict_proba.canonical_correlation_forest = function( object, newdata, verbose = FALSE, ...) { if (missing(newdata)) { stop("Argument 'newdata' is missing.") @@ -165,9 +184,6 @@ predict_prob.canonical_correlation_forest = function( ntree <- length(object$forest) treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) - if (length(unique(object$y)) > 2 ){ - stop("predict_prob currently only implemented for binary classifier. 
More than two classes detected") - } if (verbose) { cat("calculating predictions\n") diff --git a/R/cct.R b/R/cct.R index 91f3f23..e19ed6d 100644 --- a/R/cct.R +++ b/R/cct.R @@ -257,10 +257,16 @@ canonical_correlation_tree = function( } #' @export -predict.canonical_correlation_tree = function(object, newData, ...){ +predict.canonical_correlation_tree = function(object, newData, prob = FALSE, ...){ tree = object if (tree$isLeaf) { - return(tree$classIndex) + if(prob){ + #todo: make more flexible + #hardcoded to class1 right now + return(tree$trainingCounts[names(tree$trainingCounts) == "class1"]/sum(tree$trainingCounts)) + }else{ + return(tree$classIndex) + } } nr_of_features = length(tree$decisionProjection) # TODO use formula instead of all but last column diff --git a/man/predict.canonical_correlation_forest.Rd b/man/predict.canonical_correlation_forest.Rd index 853aa69..52b0d1c 100644 --- a/man/predict.canonical_correlation_forest.Rd +++ b/man/predict.canonical_correlation_forest.Rd @@ -4,7 +4,7 @@ \alias{predict.canonical_correlation_forest} \title{Prediction from canonical correlation forest} \usage{ -\method{predict}{canonical_correlation_forest}(object, newdata, verbose = FALSE, ...) +\method{predict}{canonical_correlation_forest}(object, newdata, verbose = FALSE, prob = FALSE, ...) 
} \arguments{ \item{object}{An object of class \code{canonical_correlation_forest}, as created From 447e9f8e3d8f3d3c00d4467d07445a07a1a65806 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 19:26:22 -0500 Subject: [PATCH 06/12] updated documentationfor predict function --- NAMESPACE | 1 - R/ccf.R | 28 +-------------------- man/predict.canonical_correlation_forest.Rd | 2 ++ 3 files changed, 3 insertions(+), 28 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index d45fa6d..02248c3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,7 +11,6 @@ export(canonical_correlation_forest) export(canonical_correlation_tree) export(get_missclassification_rate) export(plot_decision_surface) -export(predict_proba.canonical_correlation_forest) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,model.response) diff --git a/R/ccf.R b/R/ccf.R index 6057e58..4199344 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -124,6 +124,7 @@ canonical_correlation_forest.formula = function( #' @param newdata A data frame or a matrix containing the test data. #' @param verbose Optional argument to control if additional information are #' printed to the output. Default is \code{FALSE}. +#' @param prob boolean specifying whether to return probabilities #' @param ... Additional parameters passed on to prediction from individual #' canonical correlation trees. #' @export @@ -174,33 +175,6 @@ predict.canonical_correlation_forest = function( } } -#' @export -predict_proba.canonical_correlation_forest = function( - object, newdata, verbose = FALSE, ...) 
{ - if (missing(newdata)) { - stop("Argument 'newdata' is missing.") - } - - ntree <- length(object$forest) - treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) - - - if (verbose) { - cat("calculating predictions\n") - } - # returns list of list - treePredictions = lapply(object$forest, predict, newdata) - # convert to matrix - treePredictions = do.call(cbind, treePredictions) - if (verbose) { - cat("Majority vote\n") - } - probs <- apply(treePredictions, 1, function(row) { - sum(table(row) == "1")/ntree #hardcoded to count predictions with class name "1" - }) - - return(probs) -} #' Visualization of canonical correlation forest #' diff --git a/man/predict.canonical_correlation_forest.Rd b/man/predict.canonical_correlation_forest.Rd index 52b0d1c..57ad0ea 100644 --- a/man/predict.canonical_correlation_forest.Rd +++ b/man/predict.canonical_correlation_forest.Rd @@ -15,6 +15,8 @@ by the function \code{\link{canonical_correlation_forest}}.} \item{verbose}{Optional argument to control if additional information are printed to the output. Default is \code{FALSE}.} +\item{prob}{boolean specifying whether to return probabilities} + \item{...}{Additional parameters passed on to prediction from individual canonical correlation trees.} } From 174fc00d3ccd6d2dc8f4c1ed8cdac5fd6cb0ac06 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 19:47:49 -0500 Subject: [PATCH 07/12] propagated 'prob' argument in recursive predict call --- R/cct.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/cct.R b/R/cct.R index e19ed6d..7fa2910 100644 --- a/R/cct.R +++ b/R/cct.R @@ -283,12 +283,12 @@ predict.canonical_correlation_tree = function(object, newData, prob = FALSE, ... 
if (any(lessThanPartPoint)) { currentNodeClasses[lessThanPartPoint, ] = predict.canonical_correlation_tree(tree$refLeftChild, - X[lessThanPartPoint, ,drop = FALSE]) #nolint + X[lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint } if (any(!lessThanPartPoint)) { currentNodeClasses[!lessThanPartPoint, ] = predict.canonical_correlation_tree(tree$refRightChild, - X[!lessThanPartPoint, ,drop = FALSE]) #nolint + X[!lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint } return(currentNodeClasses) } From 6acebbcaa80083546897fe2496730e5113ef7459 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 20:04:24 -0500 Subject: [PATCH 08/12] trying to make sure right function definition is being used --- R/cct.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/cct.R b/R/cct.R index 7fa2910..f15919e 100644 --- a/R/cct.R +++ b/R/cct.R @@ -282,12 +282,12 @@ predict.canonical_correlation_tree = function(object, newData, prob = FALSE, ... currentNodeClasses = matrix(nrow = max(nrow(X), 1)) if (any(lessThanPartPoint)) { currentNodeClasses[lessThanPartPoint, ] = - predict.canonical_correlation_tree(tree$refLeftChild, + predict(tree$refLeftChild, X[lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint } if (any(!lessThanPartPoint)) { currentNodeClasses[!lessThanPartPoint, ] = - predict.canonical_correlation_tree(tree$refRightChild, + predict(tree$refRightChild, X[!lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint } return(currentNodeClasses) From 29f4f017fb2478e3beaaa836a8a2cc38e8b244f5 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 20:16:43 -0500 Subject: [PATCH 09/12] using rowMeans now --- R/ccf.R | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/R/ccf.R b/R/ccf.R index 4199344..927f33d 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -146,24 +146,14 @@ predict.canonical_correlation_forest = function( cat("calculating predictions\n") } - if(prob){ - # returns list of list - 
treePredictions = lapply(object$forest, predict, newdata, prob = TRUE) - # convert to matrix - treePredictions = do.call(cbind, treePredictions) - - # todo: use built in function - probs <- apply(treePredictions, 1, function(row) { - mean(table(row)) - }) - - return(probs) + # returns list of list + treePredictions = lapply(object$forest, predict, newdata, prob = prob) + # convert to matrix + treePredictions = do.call(cbind, treePredictions) + if(prob){ + return(rowMeans(treePredictions)) }else{ - # returns list of list - treePredictions = lapply(object$forest, predict, newdata) - # convert to matrix - treePredictions = do.call(cbind, treePredictions) if (verbose) { cat("Majority vote\n") } From ab14a21d4031b0a1d540554988bcead21dab0592 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sat, 12 Mar 2022 20:26:30 -0500 Subject: [PATCH 10/12] reverting function scoping change --- R/cct.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/cct.R b/R/cct.R index f15919e..7fa2910 100644 --- a/R/cct.R +++ b/R/cct.R @@ -282,12 +282,12 @@ predict.canonical_correlation_tree = function(object, newData, prob = FALSE, ... 
currentNodeClasses = matrix(nrow = max(nrow(X), 1)) if (any(lessThanPartPoint)) { currentNodeClasses[lessThanPartPoint, ] = - predict(tree$refLeftChild, + predict.canonical_correlation_tree(tree$refLeftChild, X[lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint } if (any(!lessThanPartPoint)) { currentNodeClasses[!lessThanPartPoint, ] = - predict(tree$refRightChild, + predict.canonical_correlation_tree(tree$refRightChild, X[!lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint } return(currentNodeClasses) From 756ed2c25f6e2e7a61c7a1d12e657e399fa7bf5b Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sun, 13 Mar 2022 01:26:19 -0500 Subject: [PATCH 11/12] Made class to compute probabilities for more flexible --- R/ccf.R | 15 +++++++-------- R/cct.R | 12 +++++------- man/predict.canonical_correlation_forest.Rd | 4 ++-- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/R/ccf.R b/R/ccf.R index 927f33d..99d6b18 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -124,20 +124,16 @@ canonical_correlation_forest.formula = function( #' @param newdata A data frame or a matrix containing the test data. #' @param verbose Optional argument to control if additional information are #' printed to the output. Default is \code{FALSE}. -#' @param prob boolean specifying whether to return probabilities +#' @param probClass Optional argument specifying name of class to compute probabilities for #' @param ... Additional parameters passed on to prediction from individual #' canonical correlation trees. #' @export predict.canonical_correlation_forest = function( - object, newdata, verbose = FALSE, prob = FALSE, ...) { + object, newdata, verbose = FALSE, probClass = NULL, ...) { if (missing(newdata)) { stop("Argument 'newdata' is missing.") } - if (prob == TRUE && length(unique(object$y)) > 2 ){ - stop("predict with prob == TRUE currently only implemented for binary classifier. 
More than two classes detected") - } - ntree <- length(object$forest) treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) @@ -147,11 +143,14 @@ predict.canonical_correlation_forest = function( } # returns list of list - treePredictions = lapply(object$forest, predict, newdata, prob = prob) + treePredictions = lapply(object$forest, predict, newdata, probClass = probClass) # convert to matrix treePredictions = do.call(cbind, treePredictions) - if(prob){ + if(!is.null(probClass)){ + if (verbose) { + cat("Mean probability\n") + } return(rowMeans(treePredictions)) }else{ if (verbose) { diff --git a/R/cct.R b/R/cct.R index 7fa2910..ac6e005 100644 --- a/R/cct.R +++ b/R/cct.R @@ -257,13 +257,11 @@ canonical_correlation_tree = function( } #' @export -predict.canonical_correlation_tree = function(object, newData, prob = FALSE, ...){ +predict.canonical_correlation_tree = function(object, newData, probClass = NULL, ...){ tree = object if (tree$isLeaf) { - if(prob){ - #todo: make more flexible - #hardcoded to class1 right now - return(tree$trainingCounts[names(tree$trainingCounts) == "class1"]/sum(tree$trainingCounts)) + if(!is.null(probClass)){ + return(tree$trainingCounts[names(tree$trainingCounts) == probClass]/sum(tree$trainingCounts)) }else{ return(tree$classIndex) } @@ -283,12 +281,12 @@ predict.canonical_correlation_tree = function(object, newData, prob = FALSE, ... 
if (any(lessThanPartPoint)) { currentNodeClasses[lessThanPartPoint, ] = predict.canonical_correlation_tree(tree$refLeftChild, - X[lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint + X[lessThanPartPoint, ,drop = FALSE], probClass = probClass) #nolint } if (any(!lessThanPartPoint)) { currentNodeClasses[!lessThanPartPoint, ] = predict.canonical_correlation_tree(tree$refRightChild, - X[!lessThanPartPoint, ,drop = FALSE], prob = prob) #nolint + X[!lessThanPartPoint, ,drop = FALSE], probClass = probClass) #nolint } return(currentNodeClasses) } diff --git a/man/predict.canonical_correlation_forest.Rd b/man/predict.canonical_correlation_forest.Rd index 57ad0ea..3638212 100644 --- a/man/predict.canonical_correlation_forest.Rd +++ b/man/predict.canonical_correlation_forest.Rd @@ -4,7 +4,7 @@ \alias{predict.canonical_correlation_forest} \title{Prediction from canonical correlation forest} \usage{ -\method{predict}{canonical_correlation_forest}(object, newdata, verbose = FALSE, prob = FALSE, ...) +\method{predict}{canonical_correlation_forest}(object, newdata, verbose = FALSE, probClass = NULL, ...) } \arguments{ \item{object}{An object of class \code{canonical_correlation_forest}, as created @@ -15,7 +15,7 @@ by the function \code{\link{canonical_correlation_forest}}.} \item{verbose}{Optional argument to control if additional information are printed to the output. 
Default is \code{FALSE}.} -\item{prob}{boolean specifying whether to return probabilities} +\item{probClass}{Optional argument specifying name of class to compute probabilities for} \item{...}{Additional parameters passed on to prediction from individual canonical correlation trees.} From 487ae7d42792a46f1f5c6fb67a85a87572a2fa07 Mon Sep 17 00:00:00 2001 From: James Petrie Date: Sun, 13 Mar 2022 01:44:41 -0500 Subject: [PATCH 12/12] added argument check to make sure specified class is one of the options --- R/ccf.R | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/R/ccf.R b/R/ccf.R index 99d6b18..bb286a4 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -133,6 +133,12 @@ predict.canonical_correlation_forest = function( if (missing(newdata)) { stop("Argument 'newdata' is missing.") } + if(!is.null(probClass)){ + classNames = names(object$forest[[1]]$trainingCounts) + if(!(probClass %in% classNames)){ + stop(paste0("Argument probClass = ", probClass, " is not in list of class names. Options are: ", paste(classNames, collapse = ', '))) + } + } ntree <- length(object$forest) treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree)