This repository was archived by the owner on May 12, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path random_forest_master_script.R
More file actions
238 lines (192 loc) · 10.5 KB
/
random_forest_master_script.R
File metadata and controls
238 lines (192 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#!/usr/bin/env Rscript
#####################################################
# Section 0: Setup: libraries, command-line arguments
#####################################################
# Make sure the checkpoint package is available, then use it to pin packages
# to the CRAN snapshot of 2021-01-01 so that results are reproducible.
# requireNamespace() checks specifically for an installed, loadable package
# named "checkpoint" (installed.packages() returns a whole matrix of package
# metadata, so `%in%` on it matched against every field, not just names).
if (!requireNamespace("checkpoint", quietly = TRUE)) {
  local({
    repos <- getOption("repos")
    # Under Rscript no CRAN mirror is usually selected (the "@CRAN@"
    # placeholder), which makes a non-interactive install.packages() fail;
    # fall back to the cloud mirror in that case only.
    if (is.null(repos) || identical(unname(repos["CRAN"]), "@CRAN@")) {
      repos <- c(CRAN = "https://cloud.r-project.org")
    }
    install.packages("checkpoint", repos = repos)
  })
}
library("checkpoint")
checkpoint("2021-01-01")
###############
# 0.1 Libraries
###############
suppressPackageStartupMessages(library("ranger"))
suppressPackageStartupMessages(library("tidyverse"))
suppressPackageStartupMessages(library("assertthat"))
suppressPackageStartupMessages(library("optparse"))
suppressPackageStartupMessages(library("rlang"))
suppressPackageStartupMessages(library("gtools"))
######################
# 0.2 Custom functions
######################
source("custom_functions/import_and_validate_dataset.R")
source("custom_functions/create_train_test_set.R")
source("custom_functions/create_list_of_train_test_sets.R")
source("custom_functions/compute_cross_validated_kfold_rf.R")
source("custom_functions/get_ranger_model_accuracy.R")
source("custom_functions/create_list_of_permuted_dataframes.R")
source("custom_functions/permute_dataframe.R")
source("custom_functions/get_accuracies_from_permuted_cv_rf_results.R")
source("custom_functions/calculate_feature_mean_sd_importance_from_kfolds.R")
source("custom_functions/get_dataframe_of_permuted_variable_importances.R")
source("custom_functions/calculate_feature_pvalue.R")
####################################
# 0.3 Parsing command-line arguments
####################################
# Command-line interface built with optparse. Every option has a default, so
# the script can also be run without arguments.
option_list <- list(
  make_option(c("-i", "--input_file"),
              type = "character",
              default = "data/breast-cancer.csv",
              help = "Path to .csv input file (n columns for X features, 1 column for sample class)",
              metavar = "filename"),
  make_option(c("-o", "--outdir"),
              type = "character",
              default = "rf_results",
              help = "output directory where to store results [default= %default]",
              metavar = "character"),
  make_option(c("-k", "--k_folds"),
              type = "integer",
              default = 5,
              metavar = "integer",
              help = "Number of k-fold cross-validation to be performed (Usually between 5 and 10 folds) [default= %default]"),
  make_option(c("-n", "--n_permutations"),
              type = "integer",
              default = 10,
              metavar = "integer",
              help = "Number of permutations (Usually > 100 and up to 1000) [default= %default]"),
  make_option(c("-t", "--n_trees"),
              type = "integer",
              default = 5000,
              metavar = "integer",
              help = "Number of trees to be build in each individual random forest (Usually > 1000 and up to 10000) [default= %default]"),
  make_option(c("-s", "--initial_seed"),
              type = "integer",
              default = 123,
              metavar = "integer",
              help = "Initial seed used for the analysis [default= %default]"),
  make_option(c("-p", "--threads"),
              type = "integer",
              default = 1,
              metavar = "integer",
              help = "Number of threads to be used [default= %default]")
)
opt_parser <- OptionParser(option_list = option_list,
                           description = "\n A program to perform a Random Forest analysis based on the ranger R package ",
                           epilogue = "Please visit https://cran.r-project.org/web//packages/ranger/ranger.pdf and https://github.com/imbs-hl/ranger for additional information")
args <- parse_args(opt_parser)

# Fail fast on numeric options that would break later sections:
# - k-fold CV needs at least 2 folds (Section 4 also needs >= 2 fold
#   accuracies for an sd and a confidence interval);
# - the permutation, tree and thread counts must be positive (with
#   n_permutations = 0 there is no null distribution for the p-values).
local({
  minimum_values <- c(k_folds = 2L, n_permutations = 1L,
                      n_trees = 1L, threads = 1L)
  for (option_name in names(minimum_values)) {
    option_value <- args[[option_name]]
    minimum <- minimum_values[[option_name]]
    if (length(option_value) != 1L || is.na(option_value) ||
        option_value < minimum) {
      stop(sprintf("--%s must be an integer >= %d (got: %s)",
                   option_name, minimum,
                   paste(option_value, collapse = ", ")),
           call. = FALSE)
    }
  }
})

# Create the output directory only if it is missing: re-running into an
# existing directory is expected, and recursive = TRUE allows nested paths.
# Genuine creation failures (e.g. permissions) still emit a warning.
if (!dir.exists(args$outdir)) {
  dir.create(args$outdir, recursive = TRUE)
}
##################################
# Section 1: reading input dataset
##################################
# Load the input csv through import_and_validate_dataset() (sourced from
# custom_functions/), then print a short summary of what was read. The first
# column is reported as the class label Y; every other column counts as a
# feature.
df <- import_and_validate_dataset(file_path = args$input_file)

local({
  separator <- "\n#####################################################################\n"
  cat(separator)
  cat("\n Section 1: reading file successfully executed! \n")
  cat("\n", nrow(df), "observations are present in your dataset. \n")
  cat("\n", ncol(df) - 1, "features/variables will be considered. \n")
  cat("\n The column named", colnames(df)[1], "will be used as your Y \n")
  cat(separator)
})
#######################################################################################
# Section 2: Compute Random Forest k-fold cross-validation analysis on original dataset
# Returns k_folds model accuracies
# Returns k_folds variable importances
#######################################################################################
# Outline:
#   Step 1: split the dataset into k_folds train/test pairs.
#   Step 2: for each train/test pair, fit a ranger random forest, measure its
#           classification accuracy, and keep the impurity-based variable
#           importances. This happens inside compute_kfold_cv_rf(), presumably
#           defined in custom_functions/compute_cross_validated_kfold_rf.R;
#           its result exposes `$accuracy` and `$var_importances`, which
#           Sections 4 and 5 read.

# Step 1: k_folds train/test pairs, reproducible through initial_seed.
train_test_sets <- create_kfold_train_test_sets(
  mydata = df,
  myseed = args$initial_seed,
  .k_folds = args$k_folds
)

# Step 2: cross-validated random forests on the original (unpermuted) data.
cv_rf_results <- compute_kfold_cv_rf(
  list_of_train_test_sets = train_test_sets,
  .num_trees = args$n_trees,
  .num.threads = args$threads,
  .importance = "impurity"
)
###################################################################################################
# Section 3: create a series of dataframes and compute random model accuracies and var. importances
# One random accuracy per permutation
# One variable importance per variable and per permutation
###################################################################################################
# n_permutations copies of the dataset; presumably the class column is
# shuffled to break the X/Y association (see
# custom_functions/create_list_of_permuted_dataframes.R — confirm).
list_of_permuted_dfs <- create_list_of_permuted_dataframes(
  .df = df,
  .initial_seed = args$initial_seed,
  .n_permutations = args$n_permutations
)

# One seed per permutation: initial_seed + 1, ..., initial_seed + n_permutations.
# seq_len() (rather than seq_along(1:n)) yields no seeds when n is 0, instead
# of the two seeds that 1:0 == c(1, 0) would produce.
seeds <- as.numeric(args$initial_seed + seq_len(args$n_permutations))

# k-fold train/test splits for every permuted dataset, each with its own seed.
train_test_permuted_sets <- map2(
  .x = list_of_permuted_dfs,
  .y = seeds,
  .f = function(permuted_df, permutation_seed) {
    create_kfold_train_test_sets(
      mydata = permuted_df,
      myseed = permutation_seed,
      .k_folds = args$k_folds
    )
  }
)

# Cross-validated random forests on every permuted dataset. The importance
# mode is passed explicitly and matches Section 2 (`"impurity"`), so the
# permuted importances are a like-for-like null distribution for the
# p-values computed in Section 5.
cv_rf_results_on_permuted_dfs <- map(
  .x = train_test_permuted_sets,
  .f = function(permuted_train_test_sets) {
    compute_kfold_cv_rf(
      list_of_train_test_sets = permuted_train_test_sets,
      .num_trees = args$n_trees,
      .num.threads = args$threads,
      .importance = "impurity"
    )
  }
)

# Accuracies of the models fitted on permuted data (the "random" model
# accuracies), collected into a one-column tibble for plotting in Section 4.
random_model_accuracies <- tibble(
  acc = get_accuracies_from_permuted_cv_rf_results(cv_rf_results_on_permuted_dfs)
)
###################################################################################
# Section 4.: plot mean/sd model accuracy versus distribution of permuted accuracies
###################################################################################

# Summarise a vector of per-fold accuracies as a one-row tibble holding the
# mean (`mean_acc`), the standard deviation (`sd_acc`), the half-width
# (`margin`) of a two-sided 95% t-based confidence interval for the mean,
# and that interval clamped to [0, 100] (`upper_conf_int`, `lower_conf_int`).
# Assumes the fold accuracies are roughly normally distributed and expressed
# on a 0-100 scale, matching the plot axis below (confirm against
# get_ranger_model_accuracy()).
# source: https://www.statology.org/confidence-interval-in-r/
summarise_accuracy_ci <- function(accuracies) {
  n_folds <- length(accuracies)
  mean_acc <- mean(accuracies)
  sd_acc <- sd(accuracies)
  # A two-sided 95% interval uses the 0.975 quantile of Student's t with
  # n - 1 degrees of freedom; qt(0.95, ...) would only give a 90% interval.
  t_crit <- qt(0.975, df = n_folds - 1)
  margin <- t_crit * sd_acc / sqrt(n_folds)
  tibble(
    mean_acc = mean_acc,
    sd_acc = sd_acc,
    margin = margin,
    upper_conf_int = pmin(mean_acc + margin, 100),
    lower_conf_int = pmax(mean_acc - margin, 0)
  )
}

cv_rf_original_model_results <- summarise_accuracy_ci(cv_rf_results$accuracy)

# Density of the permuted ("random") model accuracies, with the original
# model's mean accuracy and confidence bounds drawn as vertical lines.
p_model_accuracy <-
  ggplot(random_model_accuracies, aes(x = acc)) +
  geom_density() +
  geom_vline(data = cv_rf_original_model_results,
             aes(xintercept = mean_acc, colour = "Average")) +
  geom_vline(data = cv_rf_original_model_results,
             aes(xintercept = lower_conf_int, colour = "Lower")) +
  geom_vline(data = cv_rf_original_model_results,
             aes(xintercept = upper_conf_int, colour = "Upper")) +
  scale_x_continuous(limits = c(0, 100)) +
  scale_colour_manual(name = "Original model accuracy\n(95% confidence interval)",
                      values = c(Average = "#1b9e77", Lower = "#d95f02", Upper = "#7570b3"))

# Only draw on screen in interactive sessions: under Rscript, auto-printing
# the plot would open the default (pdf) device and leave a stray Rplots.pdf.
if (interactive()) {
  print(p_model_accuracy)
}
ggsave(filename = file.path(args$outdir, "model_accuracy.pdf"),
       plot = p_model_accuracy, width = 10, height = 7)
######################################################################################################
# Section 5: build a table of all features with their original mean/sd Gini importance and p-values
######################################################################################################
# Step 1: per-feature mean/sd importance across the k folds of the models
# fitted on the original (unpermuted) data.
original_var_importances <- calculate_feature_mean_sd_importance_from_kfolds(
  .var_imp_from_kfold_rf = cv_rf_results$var_importances
)

# Step 2: the corresponding importances for every permutation, gathered into
# one data frame by get_dataframe_of_permuted_variable_importances()
# (presumably one mean importance per feature and per permutation — confirm
# in custom_functions/).
permuted_var_importances <-
  get_dataframe_of_permuted_variable_importances(cv_rf_results_on_permuted_dfs)

# Step 3: per-feature p-values from comparing the original importances with
# the permutation distribution.
feature_pvalues <- calculate_feature_pvalue(
  original_var_importances,
  permuted_var_importances,
  .n_permutations = args$n_permutations
)

# Step 4: attach the p-values to every original feature row and write a csv.
# NOTE(review): left_join() is called without `by`, so it joins on every
# column name the two tables share; confirm that is only the feature name.
original_var_importances <- original_var_importances %>%
  left_join(feature_pvalues)
write.csv(
  x = original_var_importances,
  file = file.path(args$outdir, "feature_gini_impurity_and_pvalues.csv"),
  row.names = FALSE,
  quote = FALSE
)