|
1 | 1 | #' Plot Method for cf_calibration Objects |
2 | 2 | #' |
3 | 3 | #' Creates a calibration plot showing predicted vs observed probabilities |
4 | | -#' under the counterfactual intervention. |
| 4 | +#' under the counterfactual intervention, with an optional histogram showing |
| 5 | +#' the distribution of predicted probabilities. |
5 | 6 | #' |
6 | 7 | #' @param x A `cf_calibration` object. |
7 | | -#' @param add_histogram Logical; add histogram of predictions (default: TRUE). |
8 | | -#' @param add_rug Logical; add rug plot (default: TRUE). |
| 8 | +#' @param add_histogram Logical; add histogram of predictions below the |
| 9 | +#' calibration curve (default: TRUE). |
| 10 | +#' @param add_rug Logical; add rug plot to show individual predictions |
| 11 | +#' (default: FALSE). |
9 | 12 | #' @param ... Additional arguments passed to plotting functions. |
10 | 13 | #' |
11 | 14 | #' @return A ggplot object (if ggplot2 available) or base R plot. |
12 | 15 | #' |
13 | 16 | #' @export |
14 | | -plot.cf_calibration <- function(x, add_histogram = TRUE, add_rug = TRUE, ...) { |
| 17 | +plot.cf_calibration <- function(x, add_histogram = TRUE, add_rug = FALSE, ...) { |
15 | 18 |
|
16 | 19 | estimator_label <- toupper(x$estimator) |
17 | 20 | treatment_label <- x$treatment_level |
18 | 21 |
|
19 | 22 | if (!requireNamespace("ggplot2", quietly = TRUE)) { |
20 | | - # Base R plot |
| 23 | + # Base R plot - simple version without histogram |
| 24 | + if (add_histogram) { |
| 25 | + old_par <- par(no.readonly = TRUE) |
| 26 | + on.exit(par(old_par)) |
| 27 | + layout(matrix(c(1, 2), nrow = 2), heights = c(3, 1)) |
| 28 | + par(mar = c(0, 4, 3, 2)) |
| 29 | + } |
| 30 | + |
21 | 31 | plot(x$predicted, x$observed, |
22 | | - type = "l", lwd = 2, |
23 | | - xlab = "Predicted probability", |
| 32 | + type = "l", lwd = 2, col = "#2E86AB", |
| 33 | + xlab = if (add_histogram) "" else "Predicted probability", |
24 | 34 | ylab = sprintf("Probability under A = %s (%s)", treatment_label, estimator_label), |
25 | 35 | main = "Counterfactual Calibration Curve", |
26 | | - xlim = c(0, 1), ylim = c(0, 1)) |
| 36 | + xlim = c(0, 1), ylim = c(0, 1), |
| 37 | + xaxt = if (add_histogram) "n" else "s") |
27 | 38 | abline(0, 1, lty = 2, col = "gray") |
| 39 | + |
| 40 | + if (add_histogram && !is.null(x$predictions_raw)) { |
| 41 | + par(mar = c(4, 4, 0, 2)) |
| 42 | + hist(x$predictions_raw, breaks = 30, col = "#2E86AB", border = "white", |
| 43 | + main = "", xlab = "Predicted probability", xlim = c(0, 1)) |
| 44 | + } |
| 45 | + |
28 | 46 | return(invisible(NULL)) |
29 | 47 | } |
30 | 48 |
|
31 | 49 | # ggplot2 version |
32 | | - df <- data.frame( |
| 50 | + df_curve <- data.frame( |
33 | 51 | predicted = x$predicted, |
34 | 52 | observed = x$observed |
35 | 53 | ) |
36 | 54 |
|
37 | | - p <- ggplot2::ggplot(df, ggplot2::aes(x = .data$predicted, y = .data$observed)) + |
| 55 | + # Main calibration curve |
| 56 | + p_cal <- ggplot2::ggplot(df_curve, ggplot2::aes(x = .data$predicted, y = .data$observed)) + |
38 | 57 | ggplot2::geom_line(linewidth = 1.2, color = "#2E86AB") + |
39 | 58 | ggplot2::geom_abline(slope = 1, intercept = 0, linetype = "dashed", |
40 | 59 | color = "gray50") + |
41 | 60 | ggplot2::labs( |
42 | | - x = "Predicted probability", |
43 | 61 | y = sprintf("Probability under A = %s (%s)", treatment_label, estimator_label), |
44 | 62 | title = "Counterfactual Calibration Curve", |
45 | 63 | subtitle = sprintf("ICI = %.3f, Emax = %.3f", x$ici, x$emax) |
46 | 64 | ) + |
47 | 65 | ggplot2::coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) + |
48 | 66 | ggplot2::theme_bw() |
49 | | - |
50 | | - return(p) |
| 67 | + |
| 68 | + # Add rug if requested |
| 69 | + if (add_rug && !is.null(x$predictions_raw)) { |
| 70 | + df_raw <- data.frame(predictions_raw = x$predictions_raw) |
| 71 | + p_cal <- p_cal + |
| 72 | + ggplot2::geom_rug(data = df_raw, |
| 73 | + ggplot2::aes(x = .data$predictions_raw, y = NULL), |
| 74 | + alpha = 0.3, color = "#2E86AB") |
| 75 | + } |
| 76 | + |
| 77 | + # Add histogram subplot if requested |
| 78 | + if (add_histogram && !is.null(x$predictions_raw)) { |
| 79 | + if (!requireNamespace("patchwork", quietly = TRUE)) { |
| 80 | + # Fall back to just the calibration plot with rug |
| 81 | + message("Install 'patchwork' package for histogram subplot. Showing rug plot instead.") |
| 82 | + df_raw <- data.frame(predictions_raw = x$predictions_raw) |
| 83 | + p_cal <- p_cal + |
| 84 | + ggplot2::geom_rug(data = df_raw, |
| 85 | + ggplot2::aes(x = .data$predictions_raw, y = NULL), |
| 86 | + alpha = 0.3, color = "#2E86AB") + |
| 87 | + ggplot2::labs(x = "Predicted probability") |
| 88 | + return(p_cal) |
| 89 | + } |
| 90 | + |
| 91 | + # Remove x-axis label from calibration plot |
| 92 | + p_cal <- p_cal + |
| 93 | + ggplot2::theme(axis.title.x = ggplot2::element_blank(), |
| 94 | + axis.text.x = ggplot2::element_blank(), |
| 95 | + axis.ticks.x = ggplot2::element_blank()) |
| 96 | + |
| 97 | + # Create histogram |
| 98 | + df_raw <- data.frame(predictions_raw = x$predictions_raw) |
| 99 | + p_hist <- ggplot2::ggplot(df_raw, ggplot2::aes(x = .data$predictions_raw)) + |
| 100 | + ggplot2::geom_histogram(bins = 30, fill = "#2E86AB", color = "white", alpha = 0.8) + |
| 101 | + ggplot2::labs(x = "Predicted probability", y = "Count") + |
| 102 | + ggplot2::coord_cartesian(xlim = c(0, 1)) + |
| 103 | + ggplot2::theme_bw() + |
| 104 | + ggplot2::theme(plot.margin = ggplot2::margin(t = 0, r = 5.5, b = 5.5, l = 5.5)) |
| 105 | + |
| 106 | + # Combine plots |
| 107 | + p_combined <- patchwork::wrap_plots(p_cal, p_hist, ncol = 1, heights = c(3, 1)) |
| 108 | + return(p_combined) |
| 109 | + } |
| 110 | + |
| 111 | + # No histogram - add x-axis label |
| 112 | + p_cal <- p_cal + ggplot2::labs(x = "Predicted probability") |
| 113 | + return(p_cal) |
51 | 114 | } |
52 | 115 |
|
53 | 116 |
|
|
0 commit comments