---
title: "04_InterpModel"
output: html_document
editor_options:
  chunk_output_type: inline
---
```{r setup, include=FALSE}
library(tidyverse)
library(xgboost)
library(Metrics)
library(gridExtra)
library(feather)
library(lubridate)
knitr::opts_chunk$set(echo = TRUE)
```
## Load data
This chunk loads the monthly trend estimates, the sample-site characteristics, the monthly climate trends, and the terrain summaries, then sets the model target and the candidate predictors.
```{r load-data}
# Monthly trend estimates for each sample site
monthly <- read_feather('data/out/Trends_Pick_Monthly.feather')
#annual <- read_feather('data/out/Trends_Pick_Full.feather')

# Site characteristics: medians for most attributes, means for the burn variables
sampAtts <- read_csv('data/out/AW_GEE_Exports/SampleSiteCharacteristics_Medians.csv') %>%
  dplyr::select(-BurnArea, -BurnYear) %>%
  rename(sampID = SampID) %>%
  left_join(read_csv('data/out/AW_GEE_Exports/SampleSiteCharacteristics_Means.csv') %>%
              dplyr::select(BurnArea, BurnYear, sampID = SampID)) %>%
  rename(PF_Probability = PFProb, Permeability = PermP, SoilOrganicsCarbon = soc, FireExtent = BurnArea)

# Monthly climate trends, one column per climate variable
climTrends <- read_feather('data/out/climTrendsMonthly.feather') %>%
  dplyr::select(-p.value) %>%
  pivot_wider(names_from = name, values_from = slope)

terrain <- read_csv('data/out/AW_GEE_Exports/SampleTerrain.csv') %>%
  rename(sampID = SampID, elevation = mean, elev_sd = sdstdDev)

# Modeling choices
target <- 'group'
threshold <- "Permanent"
type <- 'frac_area'
#month_target <- 6

# sampChar is only built inside the training loop, so inspect sampAtts here
names(sampAtts)
feats <- c('PF_Probability','Permeability', 'Porosity', 'clay', 'SoilOrganicsCarbon', 'FireExtent', 'sand', 'swrad', 'p_e', 'temp', 'precip', 'evap', 'elevation', 'elev_sd')
#sampChar <- sampChar %>% left_join(climTrends %>% filter(month == month_target)) %>% left_join(terrain) #%>% left_join(sampMeans %>% dplyr::select(-filtID))
```
## First we'll train our model.
Here, I let the model choose which predictors are most important. This works a little better for xgboost than for random forests because it's not as susceptible to misplacing importance due to collinearity. All the same, we'll look at feature importance using a couple of different metrics.
Below, we train our model until the validation metric (RMSE when `target` is `'slope'`) stops improving for 10 consecutive rounds. This helps limit overfitting. The intent is to check the RMSE using both 5-fold cross-validation and holdout test data to get two different measures of model performance; note that the `xgb.cv` call is currently commented out in the training chunk.
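To make the stopping rule concrete, here is a minimal base R sketch of patience-based early stopping. The helper `early_stop` and the loss values are hypothetical and only illustrate the idea; they are not xgboost's internal implementation.

```r
# Hypothetical helper: given validation losses per boosting round, return the
# best round and the round at which training would stop.
early_stop <- function(val_loss, patience = 10) {
  best_iter <- 1
  for (i in seq_along(val_loss)) {
    if (val_loss[i] < val_loss[best_iter]) {
      best_iter <- i                      # new best validation loss
    } else if (i - best_iter >= patience) {
      # no improvement for `patience` rounds: stop here
      return(list(best_iter = best_iter, stopped_at = i))
    }
  }
  list(best_iter = best_iter, stopped_at = length(val_loss))
}

# Made-up validation RMSE values that improve and then drift upward
val_rmse <- c(1.00, 0.80, 0.70, 0.65, 0.66, 0.67, 0.68, 0.70)
res <- early_stop(val_rmse, patience = 3)
print(res)
```

The model kept for prediction is the one from `best_iter`, not the one from the last round trained.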
#### Model target = `r target`
Note: Ultimately, we should probably decide on only one modeling approach. All the interpretability metrics I use here are model agnostic, meaning we can apply them to random forest, xgboost, etc.
```{r}
shap_out <- tibble(ID = integer(), name = character(), value = numeric(), SHAP = numeric(), month = integer())
month_target <- 6
for(month_target in c(6,7,8,9)){
sampChar <- sampAtts %>% left_join(climTrends %>% filter(month == month_target)) %>% left_join(terrain) #%>% left_join(sampMeans %>% dplyr::select(-filtID))
modelInput <- monthly %>%
filter(Threshold == threshold, name == type, month==month_target) %>%
left_join(sampChar) %>%
mutate(group = ifelse(slope > 0, 1, 0)) %>%
dplyr::select(all_of(c('sampID', feats, target, 'p.value'))) %>%
rename(id = sampID) %>%
#filter(abs(slope) < 0.1) %>%
na.omit()
set.seed(2423)
if(target == 'group'){
modelInput = modelInput %>% filter(p.value < 0.05)
}
train <- modelInput %>% sample_frac(.7)
test <- modelInput %>% filter(!id %in% train$id)
dtrain <- xgb.DMatrix(data = as.matrix(train[,feats]), label = train[,target][[1]])
dtest <- xgb.DMatrix(data = as.matrix(test[,feats]), label = test[,target][[1]])
if(target == 'slope'){
params <- list(booster = "gbtree", objective = "reg:squarederror", eta=0.01, gamma=0, max_depth=4, min_child_weight=1, subsample=1, colsample_bytree=1)
xgb.naive <- xgb.train(params = params, data = dtrain, nrounds = 1000, watchlist = list(train = dtrain, val = dtest), print_every_n = 25, early_stopping_rounds = 10, maximize = F)
preds <- test %>% mutate(predicted = predict(xgb.naive, dtest))
}
#params <- list(booster = "gbtree", objective = "multi:softmax", eta=0.01, gamma=0, max_depth=2, min_child_weight=1, subsample=1, colsample_bytree=1, num_class = 3)
## Do a quick cv to check ideal number of folds
#xgbcv <- xgb.cv(params = params, data = dtrain, nrounds = 1000, nfold = 5, showsd = T, stratified = T, print_every_n = 25, early_stopping_round = 10, maximize = F)
if(target == 'group'){
params <- list(booster = "gbtree", objective = "binary:logistic", eta=0.01, gamma=0, max_depth=2, min_child_weight=1, subsample=1, colsample_bytree=1)
xgb.naive <- xgb.train(params = params, data = dtrain, nrounds = 2000, watchlist = list(train = dtrain, val = dtest), print_every_n =100, early_stopping_rounds = 10)
# Keep the predicted probability so that AUC is computed on scores, not on thresholded classes
preds <- test %>% mutate(prob = predict(xgb.naive, dtest), predicted = as.numeric(prob > 0.5))
err <- mean(preds$predicted != preds$group)
auc <- Metrics::auc(preds$group, preds$prob)
}
check <- xgb.plot.shap(data = as.matrix(train[,feats]), top_n = 14, n_col = 3, model = xgb.naive, plot = F)
df <- as_tibble(check$data) %>% mutate(ID = row_number()) %>% pivot_longer(-ID) %>%
left_join(as_tibble(check$shap_contrib) %>% mutate(ID = row_number()) %>% pivot_longer(-ID, values_to = 'SHAP')) %>%
mutate(month = month_target)
shap_out <- shap_out %>% bind_rows(df)
#p1<- ggplot(df, aes(x = value, y = SHAP)) + geom_point(size = .1) + geom_smooth(se = F) + facet_wrap(~name, scales = 'free') + ggtitle(month_target)
p2 <- xgb.ggplot.shap.summary(data = as.matrix(train[,feats]), top_n = 9, model = xgb.naive) + ggtitle(paste(month_target, ' error:',round(err,3),' auc: ',round(auc,3)))
#print(p1)
print(p2)
#importance_matrix <- xgb.importance(colnames(train[,feats]), model = xgb.naive)
#xgb.plot.importance(importance_matrix, rel_to_first = TRUE, xlab = "Relative importance")
}
```
```{r}
#vars <- c('elevation', 'SoilOrganicsCarbon', 'clay', 'PF_Probability', 'swrad', 'temp', 'precip', 'FireExtent')
vars <- c('clay', 'elevation', 'FireExtent', 'PF_Probability', 'SoilOrganicsCarbon')
renamer <- c('Clay Content', 'Elevation', 'Percent Burn Area', 'Permafrost Extent', 'Soil Organic Carbon')
renamer <- tibble(name = vars, name_new = renamer)
shap_out %>%
filter(name %in% vars) %>%
left_join(renamer) %>%
group_by(name, month) %>%
mutate(value_scaled = scale(value)) %>%
ungroup() %>%
mutate(month = month(month, label = T)) %>%
filter(value_scaled < 3) %>%
ggplot(aes(x = value, y = SHAP, color = month)) +
geom_point(size = .1, alpha=.3) +
scale_color_viridis_d(option = 'mako', end = .7) + geom_smooth(se = F) +
facet_wrap(~name_new, scales = 'free') +
theme_bw() +
theme(legend.position = c(.8,.25)) +
labs(x= 'Feature Value', y = 'Impact to Model Prediction (SHAP)', color = 'Month')
ggsave('figures/SHAPMonthly.png', width = 6.5, height = 4, units = 'in')
```
#### We want our CV RMSE to roughly equal our holdout RMSE. If the holdout RMSE is much lower, we might be overfitting.
#### `r paste0('Naive Model Hold-out score (validation metric at the best iteration) = ', xgb.naive$best_score)`
#### CV RMSE is not reported here because the `xgb.cv` call in the training chunk is commented out.
Note: We still need to think about how we want to portray our validation. The non-linear aspects of dominant wavelength make it a little tricky.
```{r, eval = F}
paste0('Naive Model Hold-out RMSE = ', xgb.naive$best_score)
#paste0('CV RMSE = ', xgbcv$evaluation_log$test_rmse_mean[xgbcv$best_iteration])
preds <- test %>% mutate(predicted = predict(xgb.naive, dtest))
ggplot(preds, aes_string(x = target, y = 'predicted')) +
geom_hex() +
scale_fill_viridis_c(trans = 'log10') +
#coord_cartesian(xlim = c(-.02,.02), ylim = c(-0.02,0.02)) +
geom_abline(color = 'red') +
ggpmisc::stat_poly_eq(formula = y~x,
aes(label = paste(..eq.label.., ..rr.label.., sep = "~~~")),
parse = TRUE)
```
## Now let's look at some interpretability metrics, first feature importance and Accumulated Local Effects (ALE) plots
Here, feature importance is permutation feature importance, or the increase in error when the values of a predictor are randomly shuffled (i.e. when we add significant noise to a given predictor).
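As a rough illustration of the shuffling idea, the base R sketch below computes permutation importance for a toy linear model. The data, the `lm` stand-in model, and the helper `perm_importance` are assumptions made for the example; reporting the error as a ratio to the unshuffled error is one common convention.

```r
set.seed(1)
n <- 500
# Toy data (assumed): y depends strongly on x1, weakly on x2, and not on x3
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y <- 3 * dat$x1 + 0.5 * dat$x2 + rnorm(n, sd = 0.5)

fit <- lm(y ~ x1 + x2 + x3, data = dat)   # stand-in for the trained model
mse <- function(obs, pred) mean((obs - pred)^2)
base_mse <- mse(dat$y, predict(fit, dat))

# Hypothetical helper: shuffle one feature, re-predict, and report the
# average ratio of shuffled error to baseline error
perm_importance <- function(feature, n_rep = 20) {
  ratios <- replicate(n_rep, {
    shuffled <- dat
    shuffled[[feature]] <- sample(shuffled[[feature]])
    mse(dat$y, predict(fit, shuffled)) / base_mse
  })
  mean(ratios)
}

imp <- sort(sapply(c("x1", "x2", "x3"), perm_importance), decreasing = TRUE)
print(round(imp, 2))
```

A ratio near 1 means shuffling the feature barely changes the error, so the model is not relying on it.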
```{r}
library(iml)
pred <- function(model, newdata){
predict(model, xgb.DMatrix(as.matrix(newdata)))
}
predictor <- Predictor$new(xgb.naive, data = train[,feats], y = train[,target][[1]], predict.function = pred)
featureImp <- FeatureImp$new(predictor, loss = 'mse')
plot(featureImp)
effs <- FeatureEffects$new(predictor)
plot(effs, fixed_y = F)
FE <- do.call('rbind', effs$results)
```
ALE plots describe the average influence of a predictor on the final prediction along a localized window of values. These plots are simple to interpret, fast to calculate, and aren't strongly affected by collinearity in the predictor space. If the ALE value is above zero, the feature has a positive impact on model predictions at the given value along the x-axis. If it's below zero, it has a negative impact. The distributions above each plot represent the distribution of values (1st-99th percentile, as set in `alePlotter`) in our training data (not sure why the plot alignment is a little off; this can be fixed down the road).
For a summary of ALE plots, see https://christophm.github.io/interpretable-ml-book/ale.html.
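The following base R sketch shows how a first-order ALE curve can be computed by hand: bin the feature by quantiles, average the prediction differences across each bin's borders, accumulate, and center. The toy data, the `lm` stand-in model, and the helper `ale_1d` are assumptions for illustration; `iml` handles more cases than this.

```r
set.seed(1)
n <- 1000
x1 <- runif(n, -2, 2)
# Toy data (assumed): x2 is deliberately collinear with x1
dat <- data.frame(x1 = x1, x2 = x1 + rnorm(n, sd = 0.3))
dat$y <- dat$x1^2 + dat$x2 + rnorm(n, sd = 0.1)
black_box <- lm(y ~ I(x1^2) + x2, data = dat)   # stand-in for the trained model

# Hypothetical helper: first-order ALE for one numeric feature
ale_1d <- function(model, data, feature, n_bins = 10) {
  x <- data[[feature]]
  breaks <- unique(quantile(x, probs = seq(0, 1, length.out = n_bins + 1)))
  bin <- cut(x, breaks = breaks, include.lowest = TRUE, labels = FALSE)
  lower <- data
  upper <- data
  lower[[feature]] <- breaks[bin]        # move each row to its bin's lower border
  upper[[feature]] <- breaks[bin + 1]    # and to its bin's upper border
  local_diff <- predict(model, upper) - predict(model, lower)
  mean_diff <- as.numeric(tapply(local_diff, bin, mean))  # local effect per bin
  n_in_bin <- as.numeric(table(bin))
  accumulated <- c(0, cumsum(mean_diff))                  # accumulate across bins
  mid <- (head(accumulated, -1) + tail(accumulated, -1)) / 2
  centered <- accumulated - sum(mid * n_in_bin) / sum(n_in_bin)
  data.frame(border = breaks, ale = centered)
}

ale_x1 <- ale_1d(black_box, dat, "x1")
ale_x2 <- ale_1d(black_box, dat, "x2")
print(ale_x1)
```

Because only rows that actually fall in a bin are moved, and only to that bin's borders, the model is never evaluated at unrealistic combinations of `x1` and `x2`. This is why ALE copes with collinear predictors.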
```{r}
alePlotter <- function(feature){
# Note: despite the variable names, these are the 1st and 99th percentiles
perc5 = quantile(train[[feature]],.01, na.rm =T)
perc95 = quantile(train[[feature]],0.99, na.rm = T)
p1 <- ggplot(train, aes_string(x = feature)) + geom_density(adjust = 4, fill = 'grey70') +
xlim(perc5,perc95) +
theme_classic() +
theme(axis.text = element_text(color = 'transparent'),
axis.title = element_blank(),
#axis.title.y = element_text(color = 'transparent'),
axis.ticks = element_blank(),
axis.line = element_line(color = 'transparent'),
plot.margin = margin(0,-1,-1,-1))
#if(feature == 'dWL'){p1 = p1 + labs(tag = 'b)')}
p2 <- FE %>% filter(.feature == feature, .borders >= perc5, .borders <= perc95) %>%
ggplot(.,aes(x= .borders, y = .value)) +
geom_line() +
geom_point() +
xlim(perc5,perc95) +
geom_hline(aes(yintercept = 0), color = 'red')+
facet_wrap(~.feature) +
theme_bw() +
theme(axis.title = element_blank(),
plot.margin = margin(-3,0,0,0))
#if(feature == 'dWL'){p2 = p2 + labs(tag = 'b)')}
arrangeGrob(p1,p2, nrow = 2, heights = c(.3,1))
}
p1 <- alePlotter(featureImp$results$feature[1])
p2 <- alePlotter(featureImp$results$feature[2])
p3 <- alePlotter(featureImp$results$feature[3])
p4 <- alePlotter(featureImp$results$feature[4])
p5 <- alePlotter(featureImp$results$feature[5])
p6 <- alePlotter(featureImp$results$feature[6])
p7 <- alePlotter(featureImp$results$feature[7])
grid.arrange(p1,p2,p3,p4,p5,p6,p7, nrow = 4, left = 'Accumulated Local Effect (nm)', bottom = 'Feature Value')
```
### Surrogate Trees can help us identify what variable splits and thresholds might be most important.
These are similar to the results that Xiao has presented, except that they're based on the *predictions* of our machine learning model and not the actual *observed* values in our dataset. This means that it's telling us important splits/thresholds in the model itself. Here I limit them to a depth of 2, but this is adjustable.
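A minimal sketch of the surrogate idea, assuming `rpart` as the tree learner (swapped in here for illustration; it is not necessarily what `TreeSurrogate` uses internally) and a toy polynomial `lm` as the stand-in black box. The shallow tree is fit to the model's *predictions*, and an R-squared against those predictions indicates how faithfully the tree mimics the model.

```r
library(rpart)

set.seed(1)
n <- 500
# Toy data (assumed): a step in x1 plus a linear effect of x2
dat <- data.frame(x1 = runif(n), x2 = runif(n))
dat$y <- ifelse(dat$x1 > 0.5, 2, 0) + dat$x2 + rnorm(n, sd = 0.1)

# Stand-in black box; in the notebook this role is played by xgb.naive
black_box <- lm(y ~ poly(x1, 5) + x2, data = dat)
dat$black_box_pred <- predict(black_box, dat)

# Surrogate: a depth-2 tree trained on the predictions, not the observations
surrogate <- rpart(black_box_pred ~ x1 + x2, data = dat,
                   control = rpart.control(maxdepth = 2))
surrogate_pred <- predict(surrogate, dat)

# Fidelity: share of the variance in the black-box predictions that the tree explains
r2 <- 1 - sum((dat$black_box_pred - surrogate_pred)^2) /
  sum((dat$black_box_pred - mean(dat$black_box_pred))^2)
print(surrogate)
print(round(r2, 3))
```

A low fidelity value would mean the splits describe the tree more than the model, so the thresholds should be read with that number in mind.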
```{r}
tree <- TreeSurrogate$new(predictor, maxdepth = 2)
plot(tree$tree)
```
### Finally, we'll look at SHAP (SHapley Additive exPlanations).
These are similar to ALE plots, but show the distribution of feature effects across all observations instead of an average over a small window. Additionally, they can tell us the feature contributions to any *individual* prediction. The methods are a little more complicated, but a simplified explanation is that SHAP splits each prediction among the features by averaging how much the prediction changes when a feature is added to different combinations of the other features. For more info, see https://christophm.github.io/interpretable-ml-book/shap.html
First, we'll look at overall SHAP results for our most important features (the dependence plots show up to twelve, the summary plot the top nine). For SHAP, the most important predictors are those with the highest cumulative impact across all observations.
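To show what a Shapley value is, the base R sketch below computes exact values for one observation of a three-feature toy model by enumerating every coalition of the other features. The model `f`, the background data `bg`, and the helpers `value` and `shapley` are all hypothetical. This brute-force version is only practical for a handful of features; the xgboost functions used in this notebook rely on a tree-specific algorithm instead.

```r
set.seed(1)
n <- 200
bg <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))  # background data (assumed)
f <- function(d) 2 * d$x1 + d$x1 * d$x2 - d$x3                 # toy model (assumed)
x_int <- data.frame(x1 = 1, x2 = 2, x3 = -1)                   # observation to explain
feat_names <- names(bg)

# Value of a coalition S: mean prediction when the features in S are fixed to
# the observation's values and the rest are taken from the background data
value <- function(S) {
  d <- bg
  for (s in S) d[[s]] <- x_int[[s]]
  mean(f(d))
}

# Exact Shapley value of feature j: weighted average of its marginal
# contribution over all coalitions of the remaining features
shapley <- function(j) {
  others <- setdiff(feat_names, j)
  p <- length(feat_names)
  phi <- 0
  for (k in 0:length(others)) {
    subsets <- if (k == 0) list(character(0)) else combn(others, k, simplify = FALSE)
    w <- factorial(k) * factorial(p - k - 1) / factorial(p)
    for (S in subsets) phi <- phi + w * (value(c(S, j)) - value(S))
  }
  phi
}

phi <- sapply(feat_names, shapley)
print(round(phi, 3))
# The contributions add up to the prediction minus the average prediction
print(all.equal(sum(phi), f(x_int) - mean(f(bg))))
```

The final check is the additivity property that makes it possible to read SHAP values as a per-feature decomposition of a single prediction.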
```{r}
check <- xgb.plot.shap(data = as.matrix(train[,feats]), top_n = 12, n_col = 3, model = xgb.naive, plot = F)
df <- as_tibble(check$data) %>% mutate(ID = row_number()) %>% pivot_longer(-ID) %>%
left_join(as_tibble(check$shap_contrib) %>% mutate(ID = row_number()) %>% pivot_longer(-ID, values_to = 'SHAP'))
ggplot(df, aes(x = value, y = SHAP)) + geom_point(size = .1) + geom_smooth(se = F) + facet_wrap(~name, scales = 'free')
check <- df %>% group_by(name) %>%
summarise(valueMean = mean(value),
valueSD = sd(value),
shapMean = mean(SHAP),
shapSD = sd(SHAP))
df %>%
left_join(check) %>%
filter(value > valueMean - 4*valueSD & value < valueMean + 4*valueSD,
SHAP > shapMean - 4*shapSD & SHAP < shapMean + 4*shapSD) %>%
ggplot(aes(x = value, y = SHAP)) + geom_point(size = .1) + geom_smooth(se = F) + facet_wrap(~name, scales = 'free')
xgb.ggplot.shap.summary(data = as.matrix(train[,feats]), top_n = 9, model = xgb.naive)
```
Then, we'll look at the SHAP summary plot for all predictors. Here, features are ordered by their SHAP importance. Each point on the summary plot is a Shapley value for a feature in one of our observations. The x-axis is its Shapley value (its contribution to the model prediction). The color represents the relative value of the feature from low to high.
```{r}
library(SHAPforxgboost)
shap_long <- shap.prep(xgb_model = xgb.naive, X_train = as.matrix(train[,feats]))
shap.plot.summary(shap_long, x_bound = .005)
shap.plot.dependence(data_long = shap_long, x = 'temp', y = 'temp', color_feature = 'PF_Probability')
shap.plot.dependence(data_long = shap_long, x = 'clay', y = 'clay', color_feature = 'PF_Probability')
```
With SHAP, we can take any individual observation and see how each input feature contributes to the final prediction, as shown below for the first observation in our training dataset. Phi, along the x-axis, denotes the impact of each feature on the final prediction. This allows us to look at any individual lake and see how our model arrives at its prediction.
```{r}
shap <- Shapley$new(predictor, x.interest = train[,feats][1,])
plot(shap)
```