diff --git a/R/calibration_by_factor.R b/R/calibration_by_factor.R index 72e0147..7dc3282 100644 --- a/R/calibration_by_factor.R +++ b/R/calibration_by_factor.R @@ -13,7 +13,7 @@ #' @import data.table #' #' @param data data.table object containing data to be plotted. (data table) -#' @param output_file filepath to which we write plot. (character) +#' @param output_file filepath to which we write plot (if NA or '', plot will not save). (character) #' @param outcome_col_name column name of outcome for which we want the mean (by factor and risk quantile) plotted on y-axis. (character) #' @param quantile_col_name column name of risk quantile to go on x-axis (could be something like y_hat_percentile). (character) #' @param cluster_by_col_name column name of grouping for which we want to cluster standard errors (typically something like empi). (character) @@ -27,7 +27,8 @@ #' vector of hex color codes matching the number of levels in plot_by_col_name. (character) #' @param ymin minimum of y-axis. (numeric) #' @param make_footnote whether or not to include a footnote on the plot (boolean) -#' @param return_plot whether or not to return created plot (boolean)' +#' @param include_legend whether or not to include a legend on the plot (boolean) +#' @param return_plot whether or not to return created plot (boolean) #' #' @return g optionally returns created plot #' @@ -55,8 +56,12 @@ plot_calibration_by_risk_quantile_by_factor <- function( SE_style = 'ribbon', ymin = 0, make_footnote = TRUE, + include_legend = TRUE, return_plot = FALSE) { + # set legend to appear at bottom of plot, or not appear at all + legend_position <- ifelse(include_legend == TRUE, 'bottom', 'none') + # compute means data[, mean_obs_outcome := mean(get(outcome_col_name)), by = c(quantile_col_name, plot_by_col_name)] @@ -102,6 +107,8 @@ plot_calibration_by_risk_quantile_by_factor <- function( n_breaks <- ifelse(uniqueN(data[, get(quantile_col_name)]) < 5, uniqueN(data[, get(quantile_col_name)]), 5) + plot_dt[, (quantile_col_name) := as.numeric(get(quantile_col_name))] # convert quantile column to numeric (if not already) for better plotting + # create calibration plot - with LINES for SE if(!SE_line | (SE_line & SE_style == 'line')) { calibration_plot <- ggplot(data = plot_dt, aes(y = mean_obs_outcome, x = get(quantile_col_name), color = factor(get(plot_by_col_name)))) + @@ -109,7 +116,7 @@ plot_calibration_by_risk_quantile_by_factor <- function( geom_line(aes(group = factor(get(plot_by_col_name)))) + color_scale + theme_bw() + - theme(legend.position = 'bottom') + + theme(legend.position = legend_position) + xlab(xlabel) + ylab(ylabel) + # determine where the axes begin, make sure y axis is % @@ -131,7 +138,7 @@ plot_calibration_by_risk_quantile_by_factor <- function( } # create calibration plot - with RIBBON for SE - if(SE_line == TRUE & SE_style == 'ribbon'){ + if (SE_line == TRUE & SE_style == 'ribbon') { calibration_plot <- ggplot(data = plot_dt, aes( y = mean_obs_outcome, x = get(quantile_col_name), @@ -145,7 +152,7 @@ plot_calibration_by_risk_quantile_by_factor <- function( xlab(xlabel) + ylab(ylabel) + theme_bw() + - theme(legend.position = 'bottom') + + theme(legend.position = legend_position) + scale_y_continuous(labels = scales::percent, limits = c(ymin, NA))+ scale_x_continuous(breaks = pretty_breaks(n = n_breaks)) } @@ -157,9 +164,11 @@ plot_calibration_by_risk_quantile_by_factor <- function( g <- calibration_plot } - # save - ggsave(output_file, g) - if(return_plot){ + # save output if path is not NA or '' + if (!is.na(output_file) & output_file != '') { + ggsave(output_file, g) + } + if (return_plot) { return(g) } }