Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions R/calibration_by_factor.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' @import data.table
#'
#' @param data data.table object containing data to be plotted. (data table)
#' @param output_file filepath to which we write plot. (character)
#' @param output_file filepath to which we write plot (if NA or '', plot will not save). (character)
#' @param outcome_col_name column name of outcome for which we want the mean (by factor and risk quantile) plotted on y-axis. (character)
#' @param quantile_col_name column name of risk quantile to go on x-axis (could be something like y_hat_percentile). (character)
#' @param cluster_by_col_name column name of grouping for which we want to cluster standard errors (typically something like empi). (character)
Expand All @@ -27,7 +27,8 @@
#' vector of hex color codes matching the number of levels in plot_by_col_name. (character)
#' @param ymin minimum of y-axis. (numeric)
#' @param make_footnote whether or not to include a footnote on the plot (boolean)
#' @param return_plot whether or not to return created plot (boolean)'
#' @param include_legend whether or not to include a legend on the plot (boolean)
#' @param return_plot whether or not to return created plot (boolean)
#'
#' @return g optionally returns created plot
#'
Expand Down Expand Up @@ -55,8 +56,12 @@ plot_calibration_by_risk_quantile_by_factor <- function(
SE_style = 'ribbon',
ymin = 0,
make_footnote = TRUE,
include_legend = TRUE,
return_plot = FALSE) {

# set legend to appear at bottom of plot, or not appear at all
legend_position <- ifelse(include_legend == TRUE, 'bottom', 'none')

# compute means
data[, mean_obs_outcome := mean(get(outcome_col_name)), by = c(quantile_col_name, plot_by_col_name)]

Expand Down Expand Up @@ -102,14 +107,16 @@ plot_calibration_by_risk_quantile_by_factor <- function(
n_breaks <- ifelse(uniqueN(data[, get(quantile_col_name)]) < 5,
uniqueN(data[, get(quantile_col_name)]), 5)

plot_dt[, (quantile_col_name) := as.numeric(get(quantile_col_name))] # convert quantile column to numeric (if not already) for better plotting

# create calibration plot - with LINES for SE
if(!SE_line | (SE_line & SE_style == 'line')) {
calibration_plot <- ggplot(data = plot_dt, aes(y = mean_obs_outcome, x = get(quantile_col_name), color = factor(get(plot_by_col_name)))) +
geom_point() +
geom_line(aes(group = factor(get(plot_by_col_name)))) +
color_scale +
theme_bw() +
theme(legend.position = 'bottom') +
theme(legend.position = legend_position) +
xlab(xlabel) +
ylab(ylabel) +
# determine where the axes begin, make sure y axis is %
Expand All @@ -131,7 +138,7 @@ plot_calibration_by_risk_quantile_by_factor <- function(
}

# create calibration plot - with RIBBON for SE
if(SE_line == TRUE & SE_style == 'ribbon'){
if (SE_line == TRUE & SE_style == 'ribbon') {
calibration_plot <- ggplot(data = plot_dt, aes(
y = mean_obs_outcome,
x = get(quantile_col_name),
Expand All @@ -145,7 +152,7 @@ plot_calibration_by_risk_quantile_by_factor <- function(
xlab(xlabel) +
ylab(ylabel) +
theme_bw() +
theme(legend.position = 'bottom') +
theme(legend.position = legend_position) +
scale_y_continuous(labels = scales::percent, limits = c(ymin, NA))+
scale_x_continuous(breaks = pretty_breaks(n = n_breaks))
}
Expand All @@ -157,9 +164,11 @@ plot_calibration_by_risk_quantile_by_factor <- function(
g <- calibration_plot
}

# save
ggsave(output_file, g)
if(return_plot){
# save output if path is not NA or ''
if (!is.na(output_file) & output_file != '') {
ggsave(output_file, g)
}
if (return_plot) {
return(g)
}
}
Expand Down