---
title: "ML Model Comparison for Fish Threat in R"
author: "Elke Windschitl"
date: "2023-10-11"
format: html
editor: source
toc: true
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE,
warning = FALSE,
message = FALSE)
```
Description: In this qmd, I evaluate different supervised machine learning algorithms for predicting IUCN Red List status of fish based on ecological and morphological characteristics. These characteristics were retrieved from FishBase and joined with the IUCN data in a separate script.
## Introduction
Global human activity threatens many species with extinction. According to the International Union for Conservation of Nature (IUCN), “More than 41,000 species are threatened with extinction. That is still 28% of all assessed species.”^1^ Increased extinction and loss of biodiversity can have severe ecological, economic, and cultural impacts. Cardinale et al.’s review of biodiversity and ecosystem services research concludes that biodiversity loss reduces the efficiency, stability, and productivity of ecological communities. Decreased productivity of ecosystem services can, in turn, have negative economic impacts^2^. Additionally, cultures worldwide have strong ties to local flora and fauna, many of which now face extinction risk. Improving our understanding of extinction risk is therefore ecologically, economically, and culturally important.
The IUCN Red List classifies species into categories according to how vulnerable they are to extinction. The Red List also includes many species listed as "Data Deficient" or "Not Evaluated", and filling these data gaps is important for conservation. For marine species, evaluating populations can be especially challenging, so it helps to build on existing knowledge when deciding where evaluation resources should be spent. Here, I build several machine learning models that predict the binary Red List status of saltwater fish from their ecological and morphological traits in FishBase. I then apply the most successful model to Red List Data Deficient and Not Evaluated species.
This work builds on my previous work, [Identifying Key Traits in Hawaiian Fish that Predict Risk of Extinction](https://elkewind.github.io/posts/2022-12-02-hawaiian-fish-analysis/). However, here I look at all fish listed on the IUCN Red List, not just those in Hawaii, and I use a Tidymodels machine learning approach.
## The Data
For my analysis, I use IUCN Red List data accessed via the IUCN Red List API^1^ and the rredlist package^3^. Consistent with Munstermann et al., living species listed as ‘Vulnerable’, ‘Endangered’, or ‘Critically Endangered’ were categorized as ‘Threatened’, and living species listed as ‘Least Concern’ or ‘Near Threatened’ were categorized as ‘Nonthreatened’^4^. I also added ‘Extinct in the Wild’ to the ‘Threatened’ category. Fully extinct species were not included. The IUCN Red List data are limited in that many marine species have not yet been listed or have been identified as too data deficient to be evaluated. This lack of data on elusive fish may introduce bias into the models.
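The grouping above can be sketched as a small base-R helper. `classify_threat()` is hypothetical (the actual wrangling happened in a separate script) and assumes the Red List categories are stored as the standard two-letter IUCN codes; any category outside the two groups, including Extinct, Data Deficient, and Not Evaluated, becomes `NA` and would be dropped.

```r
# Hypothetical helper: map IUCN Red List category codes to the binary
# IsOfConcern outcome (1 = threatened, 0 = nonthreatened).
# Assumes standard two-letter codes; everything else becomes NA.
classify_threat <- function(category) {
  threatened    <- c("VU", "EN", "CR", "EW")  # includes Extinct in the Wild
  nonthreatened <- c("LC", "NT")
  out <- rep(NA_integer_, length(category))
  out[category %in% threatened] <- 1L
  out[category %in% nonthreatened] <- 0L
  out
}

classify_threat(c("CR", "LC", "EW", "NT", "EX", "DD"))
```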
Fish ecological data were accessed from FishBase^5^ via package rfishbase^6^. Different species in the FishBase data were originally described by different people, possibly leading to errors or biases. Measurement errors in length may be present, as there are various common ways to measure the length of a fish. The species recorded in FishBase may be biased towards fish with commercial value. Data were wrangled in R and formatted in a tidy data table with the following variables.
```{r output=TRUE, echo=FALSE}
library(tidyverse)
library(knitr)
# Read in data
read_csv("/Users/elkewindschitl/Documents/data-sci/fish_data.csv") %>%
filter(!is.na(IsOfConcern)) %>%
slice_head(n = 10) %>%
kable()
```
## Methods
To get started, I load several packages. *Tidyverse* packages help with cleaning and preparing the data. *Tidymodels* packages provide almost everything I need for the machine learning steps. *kknn* is the engine for my knn model.
*knitr* is used to create kable tables, *baguette* provides my bagging model, *doParallel* allows for parallel computing on my laptop, and *vip* helps to identify variable importance. The model engines used later (*glmnet*, *rpart*, *ranger*, and *xgboost*) also need to be installed, although tidymodels loads them itself when fitting.
```{r}
# Load libraries
library(tidyverse)
library(tidymodels)
library(kknn)
library(knitr)
library(baguette)
library(doParallel)
library(vip)
```
First I read in the data. These data were cleaned and joined in a separate script, but they still need a bit of preprocessing. The outcome variable in this dataset is labeled `IsOfConcern` and indicates whether the species is at risk of extinction (1) or not (0). I start by exploring the data dimensions.
```{r}
# Read in data
fish_dat_full <- read_csv("/Users/elkewindschitl/Documents/data-sci/fish_data.csv")
fish_dat <- fish_dat_full %>%
  filter(!is.na(IsOfConcern)) # remove rows that are missing the outcome variable
# Explore some characteristics of the dataset
cols <- ncol(fish_dat)
rows <- nrow(fish_dat)
df_chars <- data.frame(
  Metric = c("Number of Columns", "Number of Rows"),
  Count = c(cols, rows))
kable(df_chars,
      col.names = c("", "Count"))
fish_dat %>%
group_by(IsOfConcern) %>%
count() %>%
kable(col.names = c("Species threat is of concern", "Count"))
```
### Data Prep
There are a lot of NA values in this dataset. I already have many columns, so I can reduce that number by removing columns with a high proportion of NA values. Here I keep only columns where 20% or fewer of the rows have NA values.
```{r}
# Calculate the proportion of NA values in each column
na_proportion <- colMeans(is.na(fish_dat))
# I want to remove columns with extreme NA counts (more than 20%)
# Define the threshold (20% or 0.20)
threshold <- 0.20
# Find columns whose proportion of NA values is at or below the threshold
columns_meeting_threshold <- names(na_proportion[na_proportion <= threshold])
# Print the column names that meet the threshold
columns_meeting_threshold %>% kable(col.names = "Columns that are below NA threshold")
# Select for just those columns that meet my criteria
fish_short <- fish_dat %>%
  select(all_of(columns_meeting_threshold))
```
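The column filter above can be checked on a toy data frame. The columns `a`, `b`, and `c` are made up for illustration: `a` is 10% missing, `b` is 50% missing, and `c` is complete, so only `a` and `c` survive a 20% threshold.

```r
# Toy data frame (hypothetical columns) to illustrate the NA-proportion filter
toy <- data.frame(
  a = c(1:9, NA),          # 1 of 10 missing -> 0.1
  b = c(rep(NA, 5), 6:10), # 5 of 10 missing -> 0.5
  c = 1:10                 # none missing    -> 0.0
)
# is.na() gives a logical matrix; colMeans() turns it into per-column proportions
na_prop <- colMeans(is.na(toy))
keep <- names(na_prop[na_prop <= 0.20])
keep
toy_short <- toy[, keep]
```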
There is still more to be done. I want to make sure numeric columns are numeric, and that character columns are treated as factors. I need to make sure my outcome variable is a factor as well. I have a small enough data frame that I am able to do this by looking at which columns are character and mutating them to be factors. I want unordered factors, and I need to remove columns that are identifiers rather than features. Because some of the algorithms I am working with here do not handle missing data well, I chose to remove all of the rows that had NA values. This did unfortunately cut down on the amount of data that I have to train and test the algorithms on.
```{r}
# Find character columns that need to be converted to factor
sapply(fish_short, class) %>% kable(col.names = c("Column", "Class"))
# List of character columns to convert to factors
character_columns_to_convert <- c("GenusSpecies", "BodyShapeI", "DemersPelag", "AirBreathing", "PriceCateg", "UsedforAquaculture", "Dangerous", "Electrogenic", "MainCommonName")
# Convert the specified character columns to factors
fish <- fish_short %>%
mutate(across(all_of(character_columns_to_convert), as.factor))
# If a feature is an ordered factor, make it unordered; remove identifying columns
fish <- fish %>%
  mutate(across(where(is.ordered), ~ factor(.x, ordered = FALSE))) %>%
  select(-c(GenusSpecies, SpecCode, MainCommonName))
# Make outcome factor
fish$IsOfConcern <- as.factor(fish$IsOfConcern)
# Remove rows with any remaining missing values
fish <- na.omit(fish)
# Check the new df
sapply(fish, class) %>% kable(col.names = c("Column", "Class"))
```
After my data are prepped, I split them into training and testing sets using a 70/30 split. Because the data are unbalanced, I stratify by the outcome variable, `IsOfConcern`, so that both sets have similar class proportions.
```{r}
set.seed(123)
# Initial 70/30 split of data, stratified by outcome (initial_split()'s default is 75/25)
fish_split <- initial_split(fish, prop = 0.7, strata = IsOfConcern)
fish_train <- training(fish_split) # Training data
fish_test <- testing(fish_split) # Test data
```
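Conceptually, stratifying means sampling the 70% separately within each outcome class, so both sets keep roughly the original class balance. A minimal base-R sketch on a made-up outcome `y` (rsample's `initial_split()` handles this internally, and its implementation differs in details):

```r
# Toy imbalanced outcome (made up): 90 nonthreatened (0) and 10 threatened (1)
y <- factor(c(rep(0, 90), rep(1, 10)))
set.seed(123)
# Split row indices by class, then sample 70% of each class for training
train_idx <- unlist(lapply(split(seq_along(y), y), function(idx) {
  sample(idx, size = round(0.7 * length(idx)))
}))
table(y[train_idx])   # training set: 63 zeros, 7 ones
table(y[-train_idx])  # testing set: 27 zeros, 3 ones
```

Without stratification, a random 70% could by chance contain noticeably more or fewer of the rare threatened class.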
### Preprocessing
I create a recipe for the preprocessing steps. I one-hot encode all factor (categorical) predictors so that each level gets its own column, and I remove columns that have no variation. Then I normalize the numeric predictors, because the lasso and knn algorithms need features on a common scale to keep features with large ranges from dominating the model. I use the same preprocessing steps for all algorithms so that the comparison is fair.
```{r}
set.seed(123)
# Preprocess the data within the recipe
fish_recipe <- recipe(IsOfConcern ~ ., data = fish_train) %>%
step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one column per factor level
step_zv(all_predictors()) %>% # drop zero-variance columns
step_normalize(all_numeric_predictors()) # center and scale numeric predictors
# Check test and train dfs look as expected
prepped <- fish_recipe %>%
prep()
fish_baked_train <- bake(prepped, fish_train)
fish_baked_test <- bake(prepped, fish_test)
# Use below to check for NA values in the entire dataframes
#any(is.na(fish_baked_train))
#any(is.na(fish_baked_test))
```
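To make the two main recipe steps concrete, here is a base-R sketch on made-up values. `model.matrix()` with the intercept removed gives one indicator column per level of a single factor (similar to `one_hot = TRUE`), and normalizing subtracts the mean and divides by the standard deviation. As with `prep()`, the mean and standard deviation are estimated on the training data only and then reused on the test data.

```r
# Made-up training data: one factor and one numeric predictor
train <- data.frame(
  shape  = factor(c("elongated", "fusiform", "elongated")),
  length = c(10, 20, 30)
)
# One-hot encoding: one 0/1 column per factor level
one_hot <- model.matrix(~ shape - 1, data = train)
one_hot
# Normalization: estimate center and scale on the training data ...
mu <- mean(train$length)  # 20
s  <- sd(train$length)    # 10
z_train <- (train$length - mu) / s  # -1, 0, 1
# ... and apply the same training statistics to new (test) values
z_test <- (c(25, 40) - mu) / s      # 0.5, 2
```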
### Dummy Classifier
Because my data are unbalanced, with many more non-threatened species, a model that always predicted non-threatened would still have high accuracy. Of course, that is not very helpful when trying to predict which species *are or might be* threatened. Here I derive a dummy accuracy by calculating the accuracy of a model that always predicts non-threatened. This serves as a baseline for judging whether a model performs better than the dummy. However, because this dataset is so imbalanced, I use the area under the receiver operating characteristic curve (ROC AUC) for model selection.
```{r}
# Calculate dummy classifier for baseline comparison
# Calculate the number of rows where IsOfConcern is 0
num_is_0 <- sum(fish_test$IsOfConcern == 0)
# Calculate the number of rows where IsOfConcern is not 0
num_is_not_0 <- nrow(fish_test) - num_is_0
# Calculate the accuracy of the dummy classifier (always predicting the majority class)
dummy <- num_is_0 / nrow(fish_test)
```
The dummy classifier accuracy is `r round(dummy, 3)`. This will serve as the baseline for other algorithms. Now I will proceed with building various models and training with the training data. I will be building Lasso, K-Nearest Neighbors, Decision Tree, Bagged Decision Tree, Random Forest, and Gradient Boosted Decision Tree models.
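A short illustration of why AUC is more informative than accuracy here: a dummy that gives every species the same score reaches an accuracy equal to the majority-class share, but its ROC AUC is exactly 0.5. The sketch uses made-up labels and a hypothetical `auc_rank()` helper based on the rank (Mann-Whitney) formulation of AUC; it is not yardstick's `roc_auc()`.

```r
# Hypothetical helper: ROC AUC via the rank (Mann-Whitney) formulation.
# scores = predicted probability of class 1; truth = 0/1 labels.
auc_rank <- function(scores, truth) {
  r <- rank(scores)  # tied scores receive their average rank
  n_pos <- sum(truth == 1)
  n_neg <- sum(truth == 0)
  (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

truth <- c(rep(0, 8), rep(1, 2))  # made-up labels: 80% non-threatened
dummy_scores <- rep(0, 10)        # same low score for every species

mean(as.integer(dummy_scores >= 0.5) == truth)  # dummy accuracy: 0.8
auc_rank(dummy_scores, truth)                   # 0.5: no ranking ability
auc_rank(truth, truth)                          # 1: perfect ranking
```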
### Lasso for Classification
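For reference, with `set_engine("glmnet")` and `mixture = 1` the model tuned below is an L1-penalized logistic regression. The objective below summarizes glmnet's documented binomial objective, where parsnip's `penalty` corresponds to $\lambda$ and `mixture` to $\alpha$:

$$
\min_{\beta_0,\,\beta}\; -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i\big(\beta_0 + x_i^\top\beta\big) - \log\big(1 + e^{\beta_0 + x_i^\top\beta}\big)\Big] + \lambda\Big[\frac{1-\alpha}{2}\lVert\beta\rVert_2^2 + \alpha\lVert\beta\rVert_1\Big]
$$

With $\alpha = 1$ only the $\lVert\beta\rVert_1$ term remains, which can shrink some coefficients exactly to zero, and larger values of $\lambda$ remove more predictors. Because the penalty acts on coefficient magnitudes, it is sensitive to the scale of each predictor, which is one reason the recipe normalizes the numeric predictors.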
```{r}
set.seed(123)
# Set up k-fold cross validation with 10 folds. This can be used for all the algorithms
fish_cv = fish_train %>%
vfold_cv(v = 10,
strata = IsOfConcern)
# Set specifications
tune_l_spec <- logistic_reg(penalty = tune(), mixture = 1) %>%
set_engine("glmnet")
# Define a workflow
wf_l <- workflow() %>%
add_model(tune_l_spec) %>%
add_recipe(fish_recipe)
# set grid
lambda_grid <- grid_regular(penalty(), levels = 50)
doParallel::registerDoParallel()
set.seed(123)
# Tune lasso model
lasso_grid <- wf_l %>%
tune_grid(
resamples = fish_cv,
grid = lambda_grid
)
# Plot the mean accuracy and AUC at each penalty
lasso_grid %>%
collect_metrics() %>%
ggplot(aes(penalty, mean, color = .metric)) +
geom_errorbar(aes(ymin = mean - std_err,
ymax = mean + std_err),
alpha = 0.5) +
geom_line(linewidth = 1.5) +
facet_wrap(~.metric,
scales = "free",
strip.position = "left",
nrow = 2, labeller = as_labeller(c(`accuracy` = "Accuracy",
`roc_auc` = "Area under ROC curve"))) +
scale_x_log10(name = "Penalty") +
scale_y_continuous(name = "") +
scale_color_manual(values = c("#4a6c75", "#57ba72")) +
theme_minimal() +
theme(
strip.placement = "outside",
legend.position = "none",
panel.background = element_blank(),
plot.background = element_blank()
) +
labs(title = "Results of penalty tuning")
# View table
lasso_grid %>%
tune::show_best(metric = "roc_auc") %>%
slice_head(n = 5) %>%
kable(caption = "Performance of the best models and the associated estimates for the penalty parameter values.")
# Select the model with the highest auc
best_lasso <- lasso_grid %>%
select_best(metric = "roc_auc")
final_l_wf <- wf_l %>%
finalize_workflow(best_lasso)
# Perform a last fit to see how the model performs on the test data
final_lasso_fit <- last_fit(final_l_wf, fish_split)
# Collect metrics on the test data
tibble_lasso <- final_lasso_fit %>% collect_metrics()
tibble_lasso %>%
kable(caption = "Accuracy and area under the receiver operating characteristic curve of the final fit.")
# Grab the model accuracy on the testing data
final_lasso_accuracy <- tibble_lasso %>%
filter(.metric == "accuracy") %>%
pull(.estimate)
final_lasso_auc <- tibble_lasso %>%
filter(.metric == "roc_auc") %>%
pull(.estimate)
# Bind predictions and original data
lasso_test_rs <- cbind(fish_test, final_lasso_fit$.predictions)[, -16]# Remove duplicate column
# Compute a confusion matrix
cm_lasso <- lasso_test_rs %>% yardstick::conf_mat(truth = IsOfConcern, estimate = .pred_class)
# Create a custom color palette
custom_palette <- scale_fill_gradient(
high = "#4a6c75",
low = "#d3e6eb"
)
# Create the confusion matrix heatmap plot
autoplot(cm_lasso, type = "heatmap") +
custom_palette + # Apply the custom color palette
theme(
axis.text.x = element_text(size = 12),
axis.text.y = element_text(size = 12),
axis.title = element_text(size = 14),
panel.background = element_blank(),
plot.background = element_blank()
) +
labs(title = "Confusion matrix of lasso predictions on test data")
# Calculate rates of true positives, false negatives, etc. from the confusion matrix (rows = prediction, columns = truth)
TP_las <- cm_lasso$table[2, 2]
FP_las <- cm_lasso$table[2, 1]
TN_las <- cm_lasso$table[1, 1]
FN_las <- cm_lasso$table[1, 2]
TPR_las <- TP_las / (TP_las + FN_las) # True Positive Rate
FPR_las <- FP_las / (FP_las + TN_las) # False Positive Rate
TNR_las <- TN_las / (TN_las + FP_las) # True Negative Rate
FNR_las <- FN_las / (TP_las + FN_las) # False Negative Rate
# Create cm df to hold all false pos, etc. metrics
lasso_cm_vec <- c(TPR_las, FPR_las, TNR_las, FNR_las)
row_names <- c("True positive rate", "False positive rate", "True negative rate", "False negative rate")
cm_df <- bind_cols(Metric = row_names, Lasso = lasso_cm_vec)
```
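The same four rates are recomputed by hand for every model below. A hypothetical helper such as `cm_rates()` could replace that repetition. It assumes the layout of yardstick's `conf_mat()$table`, with predictions in rows, truth in columns, and the levels ordered `0` then `1`, exactly as the indexing above does. The toy counts are made up.

```r
# Hypothetical helper: classification rates from a 2x2 confusion table laid
# out like yardstick's conf_mat()$table (rows = Prediction, cols = Truth,
# levels "0" then "1", with "1" = threatened).
cm_rates <- function(tab) {
  TP <- tab[2, 2]
  FP <- tab[2, 1]
  TN <- tab[1, 1]
  FN <- tab[1, 2]
  c(TPR = TP / (TP + FN),  # true positive rate (sensitivity)
    FPR = FP / (FP + TN),  # false positive rate
    TNR = TN / (TN + FP),  # true negative rate (specificity)
    FNR = FN / (TP + FN))  # false negative rate
}

# Made-up counts, filled column by column: TN = 50, FP = 5, FN = 10, TP = 15
toy_tab <- matrix(c(50, 5, 10, 15), nrow = 2,
                  dimnames = list(Prediction = c("0", "1"),
                                  Truth = c("0", "1")))
cm_rates(toy_tab)  # TPR = 0.6, FNR = 0.4
```

Because the returned vector follows the same order as `row_names`, a line such as `cm_df$KNN <- cm_rates(cm_knn$table)` would fill a column in one step.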
The accuracy for the lasso model was `r round(final_lasso_accuracy, 3)` which is slightly better than our dummy classifier that had an accuracy of `r round(dummy, 3)`. This model had an AUC of `r round(final_lasso_auc, 3)`.
### K-Nearest Neighbors
```{r}
set.seed(123)
# Define the KNN model with tuning
knn_spec_tune <- nearest_neighbor(neighbors = tune()) %>% # tune k
set_mode("classification") %>%
set_engine("kknn")
# Define a new workflow
wf_knn_tune <- workflow() %>%
add_model(knn_spec_tune) %>%
add_recipe(fish_recipe)
# Fit the workflow on the predefined folds and hyperparameters
fit_knn_cv <- wf_knn_tune %>%
tune_grid(
fish_cv,
grid = data.frame(neighbors = c(1,5,10,15,seq(20,200,10))))
# Use autoplot() to examine how different parameter configurations relate to accuracy
autoplot(fit_knn_cv) +
theme_light() +
labs(
x = "Number of neighbors (K)",
title = "Results of neighbor tuning"
) +
theme(
legend.position = "none",
panel.background = element_blank(),
plot.background = element_blank()
) +
facet_wrap(
~.metric,
nrow = 2,
labeller = labeller(.metric = c("accuracy" = "Accuracy", "roc_auc" = "Area under ROC curve"))
)
# View table
fit_knn_cv %>%
tune::show_best(metric = "roc_auc") %>%
slice_head(n = 5) %>%
kable(caption = "Performance of the best models and the associated estimates for the number of neighbors parameter values.")
# Select the model with the highest auc
best_knn <- fit_knn_cv %>%
select_best(metric = "roc_auc")
# The final workflow for our KNN model
final_knn_wf <-
wf_knn_tune %>%
finalize_workflow(best_knn)
# Use last_fit() approach to apply model to test data
final_knn_fit <- last_fit(final_knn_wf, fish_split)
# Collect metrics on the test data
tibble_knn <- final_knn_fit %>% collect_metrics()
tibble_knn %>%
kable(caption = "Accuracy and area under the receiver operating characteristic curve of the final fit.")
# Store accuracy and AUC
final_knn_accuracy <- tibble_knn %>%
filter(.metric == "accuracy") %>%
pull(.estimate)
final_knn_auc <- tibble_knn %>%
filter(.metric == "roc_auc") %>%
pull(.estimate)
# Bind predictions and original data
knn_test_rs <- cbind(fish_test, final_knn_fit$.predictions)[, -16]
# Compute a confusion matrix
cm_knn <- knn_test_rs %>% yardstick::conf_mat(truth = IsOfConcern, estimate = .pred_class)
# Create the confusion matrix heatmap plot
autoplot(cm_knn, type = "heatmap") +
custom_palette +
theme(
axis.text.x = element_text(size = 12),
axis.text.y = element_text(size = 12),
axis.title = element_text(size = 14),
panel.background = element_blank(),
plot.background = element_blank()
) +
labs(title = "Confusion matrix of knn predictions on test data")
# Calculate rates from the confusion matrix
TP_knn <- cm_knn$table[2, 2]
FP_knn <- cm_knn$table[2, 1]
TN_knn <- cm_knn$table[1, 1]
FN_knn <- cm_knn$table[1, 2]
TPR_knn <- TP_knn / (TP_knn + FN_knn) # True Positive Rate
FPR_knn <- FP_knn / (FP_knn + TN_knn) # False Positive Rate
TNR_knn <- TN_knn / (TN_knn + FP_knn) # True Negative Rate
FNR_knn <- FN_knn / (TP_knn + FN_knn) # False Negative Rate
# Add rates to cm df
knn_cm_vec <- c(TPR_knn, FPR_knn, TNR_knn, FNR_knn)
cm_df$KNN <- knn_cm_vec
```
The k-nearest neighbors model had nearly the same accuracy at predicting threat status as the dummy classifier. The accuracy of the model was `r round(final_knn_accuracy, 3)`. This model had an AUC of `r round(final_knn_auc, 3)`, which is better than that of the lasso model.
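To see why normalization matters for knn, it helps to look at the prediction rule. The sketch below is a plain majority-vote knn with Euclidean distance on made-up two-feature data. `knn_predict()` is hypothetical, and kknn (the engine used above) additionally weights neighbors with a kernel, so this is a simplification. Because the distance sums squared differences across features, a feature measured on a much larger scale would dominate it unless the features are normalized first.

```r
# Hypothetical majority-vote kNN (unweighted, Euclidean distance).
knn_predict <- function(train_x, train_y, new_x, k) {
  # Distance from new_x to every training row (new_x is recycled down columns)
  d <- sqrt(colSums((t(train_x) - new_x)^2))
  nearest <- order(d)[seq_len(k)]   # indices of the k closest rows
  votes <- table(train_y[nearest])
  names(votes)[which.max(votes)]    # most common class among the neighbors
}

# Made-up training data: two clusters with two features each
train_x <- matrix(c(1, 1, 2, 2, 8, 8, 9, 9), ncol = 2, byrow = TRUE)
train_y <- c("0", "0", "1", "1")

knn_predict(train_x, train_y, c(1.5, 1.5), k = 3)  # "0"
knn_predict(train_x, train_y, c(8.5, 8.5), k = 3)  # "1"
```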
### Decision Tree
```{r}
# Tell the model that we are tuning hyperparams
tree_spec_tune <- decision_tree(
cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
# Set up grid
tree_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(), levels = 5)
# Check grid
#tree_grid
# Define a workflow with the recipe and specification
wf_tree_tune <- workflow() %>%
add_recipe(fish_recipe) %>%
add_model(tree_spec_tune)
doParallel::registerDoParallel(cores = 3) #build trees in parallel
# Tune
tree_rs <- tune_grid(
wf_tree_tune,
resamples = fish_cv,
grid = tree_grid,
metrics = metric_set(roc_auc)
)
# Use autoplot() to examine how different parameter configurations relate to auc
autoplot(tree_rs) +
theme_light() +
scale_color_manual(values = c("#4a6c75", "#57ba72", "#d596e0", "#e06d53", "#d6cf81")) +
labs(x = "Cost-complexity parameter",
y = "Area under the ROC curve",
title = "Results of tree tuning") +
theme(
plot.title = element_text(size = 16, hjust = 0.5),
panel.background = element_blank(),
plot.background = element_blank()
)
# View table
tree_rs %>%
tune::show_best(metric = "roc_auc") %>%
slice_head(n = 5) %>%
kable(caption = "Performance of the best models and the associated estimates for the tuned tree parameter values.")
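# Finalize the model specs with the best hyperparameter result
best_tree <- select_best(tree_rs, metric = "roc_auc")
final_tree <- finalize_model(tree_spec_tune, best_tree)
# Final fit through the finalized workflow, so the test data get the same recipe
# used during tuning; last_fit() trains on the full training set, then predicts on the test set
final_tree_wf <- finalize_workflow(wf_tree_tune, best_tree)
final_tree_fit <- last_fit(final_tree_wf, fish_split)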
# Collect metrics from fit
tibble_tree <- final_tree_fit %>% collect_metrics()
tibble_tree %>% kable(caption = "Accuracy and area under the receiver operating characteristic curve of the final fit.")
# Store accuracy and auc metrics
final_tree_accuracy <- tibble_tree %>%
filter(.metric == "accuracy") %>%
pull(.estimate)
final_tree_auc <- tibble_tree %>%
filter(.metric == "roc_auc") %>%
pull(.estimate)
# Bind predictions and original data
tree_test_rs <- cbind(fish_test, final_tree_fit$.predictions)[, -16]
# Compute a confusion matrix
cm_tree <- tree_test_rs %>% yardstick::conf_mat(truth = IsOfConcern, estimate = .pred_class)
# Create the confusion matrix heatmap plot
autoplot(cm_tree, type = "heatmap") +
custom_palette +
theme(
axis.text.x = element_text(size = 12),
axis.text.y = element_text(size = 12),
axis.title = element_text(size = 14),
panel.background = element_blank(),
plot.background = element_blank()
) +
labs(title = "Confusion matrix of decision tree predictions on test data")
# Calculate rates from the confusion matrix
TP_tree <- cm_tree$table[2, 2]
FP_tree <- cm_tree$table[2, 1]
TN_tree <- cm_tree$table[1, 1]
FN_tree <- cm_tree$table[1, 2]
TPR_tree <- TP_tree / (TP_tree + FN_tree) # True Positive Rate
FPR_tree <- FP_tree / (FP_tree + TN_tree) # False Positive Rate
TNR_tree <- TN_tree / (TN_tree + FP_tree) # True Negative Rate
FNR_tree <- FN_tree / (TP_tree + FN_tree) # False Negative Rate
# Add rates to cm df
tree_cm_vec <- c(TPR_tree, FPR_tree, TNR_tree, FNR_tree)
cm_df$DecisionTree <- tree_cm_vec
```
The decision tree model had a higher accuracy at predicting threat status than the dummy classifier. The accuracy of the decision tree was `r round(final_tree_accuracy, 3)`. The AUC is `r round(final_tree_auc, 3)`, which is lower than that of the lasso and knn models.
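The `cost_complexity` parameter tuned above is passed to rpart as its complexity parameter `cp` (and `tree_depth` and `min_n` as `maxdepth` and `minsplit`). A minimal sketch of its effect, using rpart directly on R's built-in `iris` data rather than the fish data: splits that do not improve the overall fit by at least a factor of `cp` are pruned away, so larger values leave fewer leaves.

```r
library(rpart)  # ships with R as a recommended package

# Grow a deliberately large tree (cp = 0, tiny minsplit) ...
full <- rpart(Species ~ ., data = iris, method = "class",
              control = rpart.control(cp = 0, minsplit = 2))
# ... then prune it back with a larger complexity parameter
pruned <- prune(full, cp = 0.1)

# Count terminal nodes (leaves) before and after pruning
n_leaves <- function(fit) sum(fit$frame$var == "<leaf>")
c(full = n_leaves(full), pruned = n_leaves(pruned))
```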
### Bagging
```{r}
set.seed(123)
# Set bagging tuning specifications
bag_spec <-
bag_tree(cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()) %>%
set_engine("rpart", times = 75) %>% # 75 ensemble members (bootstrap resamples)
set_mode("classification")
# Set up tuning grid
bag_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(), levels = 5)
# Check the grid space
#bag_grid
# Set up bagging workflow
wf_bag <- workflow() %>%
add_recipe(fish_recipe) %>%
add_model(bag_spec)
doParallel::registerDoParallel() #build trees in parallel
# Run tuning
bag_rs <- tune_grid(
wf_bag,
resamples = fish_cv,
grid = bag_grid,
metrics = metric_set(roc_auc)
)
# Use autoplot() to examine how different parameter configurations relate to auc
autoplot(bag_rs) +
theme_light() +
scale_color_manual(values = c("#4a6c75", "#57ba72", "#d596e0", "#e06d53", "#d6cf81")) +
labs(x = "Cost-complexity parameter",
y = "Area under the ROC curve",
title = "Results of bagging tuning") +
theme(
plot.title = element_text(size = 16, hjust = 0.5),
panel.background = element_blank(),
plot.background = element_blank()
)
# View table
bag_rs %>%
tune::show_best(metric = "roc_auc") %>%
slice_head(n = 5) %>%
kable(caption = "Performance of the best models and the associated estimates for the tuned tree parameter values.")
# Finalize the model specs with the best performing model
best_bag <- select_best(bag_rs, metric = "roc_auc")
final_bag <- finalize_model(bag_spec, best_bag)
# Perform a last fit through the finalized workflow so the test data get the same recipe
final_bag_wf <- finalize_workflow(wf_bag, best_bag)
final_bag_fit <- last_fit(final_bag_wf, fish_split)
# Collect metrics from fit
tibble_bag <- final_bag_fit %>% collect_metrics()
tibble_bag %>%
kable(caption = "Accuracy and area under the receiver operating characteristic curve of the final fit.")
# Store model accuracy and auc on testing data
final_bag_accuracy <- tibble_bag %>%
filter(.metric == "accuracy") %>%
pull(.estimate)
final_bag_auc <- tibble_bag %>%
filter(.metric == "roc_auc") %>%
pull(.estimate)
# Bind predictions and original data
bag_test_rs <- cbind(fish_test, final_bag_fit$.predictions)[, -16]
# Compute a confusion matrix
cm_bag <- bag_test_rs %>% yardstick::conf_mat(truth = IsOfConcern, estimate = .pred_class)
# Create the confusion matrix heatmap plot
autoplot(cm_bag, type = "heatmap") +
custom_palette + # Apply the custom color palette
theme(
axis.text.x = element_text(size = 12),
axis.text.y = element_text(size = 12),
axis.title = element_text(size = 14),
panel.background = element_blank(),
plot.background = element_blank()
) +
labs(title = "Confusion matrix of bagging predictions on test data")
# Calculate rates from the confusion matrix
TP_bag <- cm_bag$table[2, 2]
FP_bag <- cm_bag$table[2, 1]
TN_bag <- cm_bag$table[1, 1]
FN_bag <- cm_bag$table[1, 2]
TPR_bag <- TP_bag / (TP_bag + FN_bag) # True Positive Rate
FPR_bag <- FP_bag / (FP_bag + TN_bag) # False Positive Rate
TNR_bag <- TN_bag / (TN_bag + FP_bag) # True Negative Rate
FNR_bag <- FN_bag / (TP_bag + FN_bag) # False Negative Rate
# Add rates to cm df
bag_cm_vec <- c(TPR_bag, FPR_bag, TNR_bag, FNR_bag)
cm_df$Bagging <- bag_cm_vec
```
The bagged model's accuracy at predicting threat status was similar to that of the decision tree and lasso models. The accuracy of the bagged model was `r round(final_bag_accuracy, 3)`. The AUC is `r round(final_bag_auc, 3)`, which is similar to that of the knn model.
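Bagging (bootstrap aggregating) fits each tree to a bootstrap resample of the training rows and combines the trees' predictions by majority vote. The sketch below is a hypothetical `bag_predict()` built directly on rpart and the built-in `iris` data, not baguette's implementation; its `times` argument plays the same role as `times = 75` in the engine call above.

```r
library(rpart)

# Hypothetical bagging sketch: bootstrap the training rows, fit one
# classification tree per resample, then take a majority vote.
bag_predict <- function(formula, train, newdata, times = 25) {
  votes <- vapply(seq_len(times), function(i) {
    boot <- train[sample(nrow(train), replace = TRUE), ]  # bootstrap resample
    fit <- rpart(formula, data = boot, method = "class")
    as.character(predict(fit, newdata, type = "class"))
  }, character(nrow(newdata)))
  votes <- matrix(votes, nrow = nrow(newdata))  # rows = observations, cols = trees
  # Majority vote across trees for each observation
  apply(votes, 1, function(v) names(which.max(table(v))))
}

set.seed(123)
pred <- bag_predict(Species ~ ., train = iris, newdata = iris)
mean(pred == iris$Species)  # resubstitution accuracy, optimistic by design
```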
### Random Forest
```{r}
set.seed(123)
# Set forest specifications
forest_spec <-
rand_forest(min_n = tune(),
mtry = tune(),
trees = tune()) %>%
set_engine("ranger") %>%
set_mode("classification")
# Set grid for tuning
forest_grid <- grid_regular(min_n(), mtry(c(1,44)), trees(), levels = 5)
# Check the grid space
#forest_grid
# Set the workflow with the recipe and forest specs
wf_forest <- workflow() %>%
add_recipe(fish_recipe) %>%
add_model(forest_spec)
doParallel::registerDoParallel()
# Perform tuning
forest_rs <- tune_grid(
wf_forest,
resamples = fish_cv,
grid = forest_grid,
metrics = metric_set(roc_auc)
)
# Use autoplot() to examine how different parameter configurations relate to auc
autoplot(forest_rs) +
theme_light() +
scale_color_manual(values = c("#4a6c75", "#57ba72", "#d596e0", "#e06d53", "#d6cf81")) +
labs(x = "# Randomly selected predictors",
y = "Area under the ROC curve",
title = "Results of random forest tuning") +
theme(
plot.title = element_text(size = 16, hjust = 0.5),
panel.background = element_blank(),
plot.background = element_blank()
)
# View table
forest_rs %>%
tune::show_best(metric = "roc_auc") %>%
slice_head(n = 5) %>%
kable(caption = "Performance of the best models and the associated estimates for the tree and forest parameter values.")
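# Finalize model with specs and best hyperparameters
best_forest <- select_best(forest_rs, metric = "roc_auc")
final_forest <- finalize_model(forest_spec, best_forest)
# Perform last fit through the finalized workflow so the test data get the same recipe
final_forest_wf <- finalize_workflow(wf_forest, best_forest)
final_forest_fit <- last_fit(final_forest_wf, fish_split)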
# Collect performance metrics
tibble_forest <- final_forest_fit %>% collect_metrics()
tibble_forest %>%
kable(caption = "Accuracy and area under the receiver operating characteristic curve of the final fit.")
# Store accuracy and auc metrics
final_forest_accuracy <- tibble_forest %>%
filter(.metric == "accuracy") %>%
pull(.estimate)
final_forest_auc <- tibble_forest %>%
filter(.metric == "roc_auc") %>%
pull(.estimate)
# Bind predictions and original data
forest_test_rs <- cbind(fish_test, final_forest_fit$.predictions)[, -16]
# Compute a confusion matrix
cm_forest <- forest_test_rs %>% yardstick::conf_mat(truth = IsOfConcern, estimate = .pred_class)
# Create the confusion matrix heatmap plot
autoplot(cm_forest, type = "heatmap") +
custom_palette + # Apply the custom color palette
theme(
axis.text.x = element_text(size = 12),
axis.text.y = element_text(size = 12),
axis.title = element_text(size = 14),
panel.background = element_blank(),
plot.background = element_blank()
) +
labs(title = "Confusion matrix of random forest predictions on test data")
# Calculate rates from the confusion matrix
TP_forest <- cm_forest$table[2, 2]
FP_forest <- cm_forest$table[2, 1]
TN_forest <- cm_forest$table[1, 1]
FN_forest <- cm_forest$table[1, 2]
TPR_forest <- TP_forest / (TP_forest + FN_forest) # True Positive Rate
FPR_forest <- FP_forest / (FP_forest + TN_forest) # False Positive Rate
TNR_forest <- TN_forest / (TN_forest + FP_forest) # True Negative Rate
FNR_forest <- FN_forest / (TP_forest + FN_forest) # False Negative Rate
# Add rates to cm df
forest_cm_vec <- c(TPR_forest, FPR_forest, TNR_forest, FNR_forest)
cm_df$RandomForest <- forest_cm_vec
```
The accuracy of the random forest was `r round(final_forest_accuracy, 3)`, again similar to the other models. This model had an AUC of `r round(final_forest_auc, 3)`.
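A random forest differs from bagging in one respect: at each split, only `mtry` randomly chosen predictors are considered, so setting `mtry` to the total number of predictors recovers bagged trees. (parsnip passes `trees`, `mtry`, and `min_n` to ranger as `num.trees`, `mtry`, and `min.node.size`.) A minimal sketch with ranger on the built-in `iris` data, comparing out-of-bag (OOB) error for the two settings; it assumes the ranger package is installed, and the exact errors depend on the seed.

```r
library(ranger)

set.seed(123)
p <- ncol(iris) - 1  # number of predictors (4)

# Random forest: 2 randomly chosen predictors considered at each split
rf <- ranger(Species ~ ., data = iris, num.trees = 500, mtry = 2)
# mtry = p: every predictor is considered at every split, i.e. bagged trees
bagged <- ranger(Species ~ ., data = iris, num.trees = 500, mtry = p)

# prediction.error is the OOB misclassification rate for classification forests
c(forest = rf$prediction.error, bagged = bagged$prediction.error)
```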
### Boosting
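As background for the two-stage tuning below, here is the general gradient boosting update. Each new tree $h_m$ is fitted to the negative gradient of the loss with respect to the current predictions, and its contribution is shrunk by the learning rate $\eta$ (`learn_rate`). The xgboost engine uses a second-order (gradient and Hessian) variant of this idea with added regularization, so this is a simplified summary:

$$
F_m(x) = F_{m-1}(x) + \eta\, h_m(x), \qquad m = 1, \dots, M
$$

With $M$ fixed at `trees = 3000`, a smaller $\eta$ means each tree changes the prediction less, which is why the learning rate is tuned before the tree-level parameters. In parsnip's xgboost translation, `tree_depth`, `min_n`, and `loss_reduction` correspond to xgboost's `max_depth`, `min_child_weight`, and `gamma`.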
```{r}
# Tune the learning rate first
# Set up the model specification for learning rate tuning
lr_spec <- parsnip::boost_tree(
  mode = "classification",
  engine = "xgboost",
  trees = 3000,
  learn_rate = tune()
)
# Set up the tuning grid
lr_grid <- expand.grid(learn_rate = seq(0.0001, 0.5, length.out = 50))
# Set up the workflow
wf_lr_tune <- workflow() %>%
  add_recipe(fish_recipe) %>%
  add_model(lr_spec)
doParallel::registerDoParallel()
set.seed(123)
# Tune; the outcome and predictors come from the recipe in the workflow
lr_rs <- tune_grid(
  wf_lr_tune,
  resamples = fish_cv,
  grid = lr_grid
)
# Use autoplot() to examine how the learning rate relates to accuracy and ROC AUC
autoplot(lr_rs) +
  theme_light() +
  labs(
    x = "Learning rate",
    title = "Results from learning rate tuning"
  ) +
  theme(
    legend.position = "none",
    panel.background = element_blank(),
    plot.background = element_blank()
  ) +
  facet_wrap(
    ~.metric,
    nrow = 2,
    labeller = labeller(.metric = c("accuracy" = "Accuracy", "roc_auc" = "Area under ROC curve"))
  )
# Identify the best values from the tuning process
lr_rs %>%
  tune::show_best(metric = "roc_auc") %>%
  slice_head(n = 5) %>%
  kable(caption = "Performance of the best models and the associated learning rate values.")
# Select the best learning rate
best_learn <- lr_rs %>%
  tune::select_best(metric = "roc_auc")
# Tune the tree parameters next
# Create a new specification that fixes the learning rate and tunes the tree parameters
boost_tree_spec <- parsnip::boost_tree(
  mode = "classification",
  engine = "xgboost",
  trees = 3000,
  learn_rate = best_learn$learn_rate,
  min_n = tune(),
  tree_depth = tune(),
  loss_reduction = tune()
)
# Define the parameters to be tuned
boost_params <- dials::parameters(
  min_n(),
  tree_depth(),
  loss_reduction()
)
# Set up a tuning grid with grid_max_entropy() to get a representative sampling of the parameter space
boost_tree_grid <- dials::grid_max_entropy(boost_params, size = 50)
# Set up the workflow
wf_boost_tree_tune <- workflow() %>%
  add_recipe(fish_recipe) %>%
  add_model(boost_tree_spec)
set.seed(123)
doParallel::registerDoParallel()
# Tune; the outcome and predictors come from the recipe in the workflow
boost_tree_rs <- tune_grid(
  wf_boost_tree_tune,
  resamples = fish_cv,
  grid = boost_tree_grid
)
# Identify the best values from the tuning process
boost_tree_rs %>%
  tune::show_best(metric = "roc_auc") %>%
  slice_head(n = 5) %>%
  kable(caption = "Performance of the best models and the associated tree parameter values.")
# Select the best tree hyperparameters
boost_best_trees <- boost_tree_rs %>%
  tune::select_best(metric = "roc_auc")
# Tune the stochastic parameters
# Create another specification that fixes the learning rate and tree parameters and tunes the stochastic parameters
boost_stoc_spec <- parsnip::boost_tree(
  mode = "classification",
  engine = "xgboost",
  trees = 3000,
  learn_rate = best_learn$learn_rate,
  min_n = boost_best_trees$min_n,
  tree_depth = boost_best_trees$tree_depth,
  mtry = tune(),
  loss_reduction = boost_best_trees$loss_reduction,
  sample_size = tune(),
  stop_iter = tune()
)
# Define the parameters to be tuned; the upper bound of mtry is finalized
# from the number of predictors in the baked training data
boost_stoc_params <- dials::parameters(
  finalize(mtry(), select(fish_baked_train, -IsOfConcern)),
  sample_size = sample_prop(c(.4, .9)),
  stop_iter()
)
# Set up the grid, again with grid_max_entropy()
boost_stoc_grid <- dials::grid_max_entropy(boost_stoc_params, size = 50)
# Set up the workflow
wf_boost_stoc <- workflow() %>%
  add_recipe(fish_recipe) %>%
  add_model(boost_stoc_spec)
set.seed(123)
doParallel::registerDoParallel()
# Tune; the outcome and predictors come from the recipe in the workflow
boost_stoc_rs <- tune_grid(
  wf_boost_stoc,
  resamples = fish_cv,
  grid = boost_stoc_grid
)
# Identify the best values from the tuning process
boost_stoc_rs %>%
  tune::show_best(metric = "roc_auc") %>%
  slice_head(n = 5) %>%
  kable(caption = "Performance of the best models and the associated stochastic parameter values.")
# Select the best stochastic hyperparameters
boost_best_stoch <- boost_stoc_rs %>%
  tune::select_best(metric = "roc_auc")
# Finalize the workflow
# Assemble a final specification with all of the optimized parameters and do a final fit
boost_final_spec <- parsnip::boost_tree(
  mode = "classification",
  engine = "xgboost",
  trees = 1000,
  learn_rate = best_learn$learn_rate,
  min_n = boost_best_trees$min_n,
  tree_depth = boost_best_trees$tree_depth,
  mtry = boost_best_stoch$mtry,
  loss_reduction = boost_best_trees$loss_reduction,
  stop_iter = boost_best_stoch$stop_iter,
  sample_size = boost_best_stoch$sample_size
)
# Set up the workflow
wf_boost_final <- workflow() %>%
  add_recipe(fish_recipe) %>%
  add_model(boost_final_spec)
# Fit the final workflow to the training data only (needed later)
final_simple_fit <- wf_boost_final %>%
  fit(data = fish_train)
# Final fit: train on the training split and evaluate once on the test split
final_boost_fit <- last_fit(boost_final_spec, IsOfConcern ~ ., fish_split)
# Collect performance metrics
tibble_boost <- final_boost_fit %>% collect_metrics()
tibble_boost %>%
  kable(caption = "Accuracy and area under the receiver operating characteristic (ROC) curve of the final fit.")
# Store accuracy and AUC metrics
final_boost_accuracy <- tibble_boost %>%
  filter(.metric == "accuracy") %>%
  pull(.estimate)
final_boost_auc <- tibble_boost %>%
  filter(.metric == "roc_auc") %>%
  pull(.estimate)
# Bind predictions and original data
boost_test_rs <- cbind(fish_test, final_boost_fit$.predictions)[, -16]
# Compute a confusion matrix
cm_boost <- boost_test_rs %>%
  yardstick::conf_mat(truth = IsOfConcern, estimate = .pred_class)
# Create the confusion matrix heatmap plot
autoplot(cm_boost, type = "heatmap") +
  custom_palette + # Apply the custom color palette
  theme(
    axis.text.x = element_text(size = 12),
    axis.text.y = element_text(size = 12),
    axis.title = element_text(size = 14),
    panel.background = element_blank(),
    plot.background = element_blank()
  ) +
  labs(title = "Confusion matrix of boosted tree predictions on test data")
# Calculate rates from the confusion matrix
TP_boost <- cm_boost$table[2, 2]
FP_boost <- cm_boost$table[2, 1]
TN_boost <- cm_boost$table[1, 1]
FN_boost <- cm_boost$table[1, 2]
TPR_boost <- TP_boost / (TP_boost + FN_boost) # True Positive Rate
FPR_boost <- FP_boost / (FP_boost + TN_boost) # False Positive Rate
TNR_boost <- TN_boost / (TN_boost + FP_boost) # True Negative Rate
FNR_boost <- FN_boost / (TP_boost + FN_boost) # False Negative Rate
# Add rates to cm df
boost_cm_vec <- c(TPR_boost, FPR_boost, TNR_boost, FNR_boost)
cm_df$Boosting <- boost_cm_vec
```
The boosted model reached an accuracy of `r round(final_boost_accuracy, 3)`, also similar to the other models. Its AUC of `r round(final_boost_auc, 3)` was the highest of all the models, though only by a small margin.
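The stochastic stage above also tunes `stop_iter`, which parsnip passes to xgboost as its early-stopping patience: training halts once the monitored evaluation metric has gone `stop_iter` consecutive rounds without improving, so the fitted model may stop well before the `trees` limit. A minimal base-R sketch of that patience rule follows, assuming a hypothetical vector of per-round validation losses (lower is better); `early_stop()` is an illustrative helper, not an xgboost or parsnip function.

```r
# Sketch of a patience-based early-stopping rule, in the spirit of xgboost's
# early stopping (which parsnip's `stop_iter` maps to).
# `val_loss` is a hypothetical per-round validation loss (lower is better).
early_stop <- function(val_loss, patience) {
  best_loss <- Inf
  best_round <- 0L
  for (round in seq_along(val_loss)) {
    if (val_loss[round] < best_loss) {
      # New best loss: record it, which restarts the patience count
      best_loss <- val_loss[round]
      best_round <- round
    } else if (round - best_round >= patience) {
      # No improvement for `patience` consecutive rounds: stop here
      return(list(stopped_at = round, best_round = best_round))
    }
  }
  # Ran out of rounds before the patience was exhausted
  list(stopped_at = length(val_loss), best_round = best_round)
}

losses <- c(0.60, 0.52, 0.47, 0.45, 0.46, 0.45, 0.47, 0.48, 0.44)
early_stop(losses, patience = 3)
# stopped_at = 7, best_round = 4
```

With a patience of 3, training stops at round 7 and never reaches the slightly lower loss at round 9; a larger patience trades extra computation for a lower chance of stopping too early.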
## Model Results
### Model selection
Here I compare the accuracy and the area under the ROC curve (AUC) of all the models.
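Accuracy depends on a single classification threshold, whereas the ROC AUC summarizes how well a model ranks cases across all thresholds. One way to read it: the AUC is the probability that a randomly chosen positive case receives a higher predicted probability than a randomly chosen negative case, with ties counted as one half. The base-R sketch below computes that pairwise quantity for hypothetical labels and scores; `pairwise_auc()` is an illustrative helper, and for a binary outcome it should agree with the trapezoidal area that yardstick's `roc_auc()` reports.

```r
# Pairwise (Mann-Whitney) definition of ROC AUC: the share of
# (positive, negative) pairs in which the positive case gets the higher
# score, counting ties as half a win.
pairwise_auc <- function(truth, score, positive) {
  pos <- score[truth == positive]
  neg <- score[truth != positive]
  # outer() compares every positive score with every negative score
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

# Hypothetical labels and predicted probabilities of being "of concern"
truth <- c("Yes", "Yes", "Yes", "No", "No", "No", "No")
score <- c(0.90, 0.70, 0.40, 0.60, 0.40, 0.20, 0.10)
pairwise_auc(truth, score, positive = "Yes")
# 0.875
```

Of the 12 positive/negative pairs, 10 are ranked correctly and one is tied, giving (10 + 0.5) / 12 = 0.875.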
```{r}
# Name the models
models <- c("Dummy", "Lasso", "KNN", "Decision Tree", "Bagging", "Random Forest", "Boosting")
# Collect the accuracy of each model, in the same order
accuracy <- c(dummy, final_lasso_accuracy, final_knn_accuracy, final_tree_accuracy,
              final_bag_accuracy, final_forest_accuracy, final_boost_accuracy)
# Make a data frame
accuracy_df <- data.frame(models, accuracy)
# Convert models to a factor so the bars keep this order in the plot
accuracy_df$models <- factor(accuracy_df$models, levels = models)
# Create the plot
ggplot(accuracy_df, aes(x = models, y = accuracy)) +
  geom_col(fill = "#4a6c75") +
  theme_minimal() +
  labs(title = "Accuracy was similar across all models",
       x = "Model",
       y = "Accuracy") +
  geom_text(aes(label = round(accuracy, 3)), vjust = -0.5) +
  theme(plot.background = element_blank(),