Skip to content

Commit f84aa2d

Browse files
committed
Add histogram subplot to calibration plots
- Store raw predictions in cf_calibration and tr_calibration results - Add histogram subplot below calibration curve using patchwork - Supports both ggplot2 (with patchwork) and base R graphics - add_histogram parameter controls display (default: TRUE) - Falls back gracefully if patchwork not installed
1 parent 4f305e8 commit f84aa2d

5 files changed

Lines changed: 149 additions & 36 deletions

File tree

R/cf_calibration.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ cf_calibration <- function(predictions,
185185
result <- list(
186186
predicted = predicted,
187187
observed = observed,
188+
predictions_raw = pred_use, # Raw predictions for histogram
188189
weights = weights,
189190
smoother = smoother,
190191
estimator = estimator,

R/plot.R

Lines changed: 76 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,116 @@
11
#' Plot Method for cf_calibration Objects
22
#'
33
#' Creates a calibration plot showing predicted vs observed probabilities
4-
#' under the counterfactual intervention.
4+
#' under the counterfactual intervention, with an optional histogram showing
5+
#' the distribution of predicted probabilities.
56
#'
67
#' @param x A `cf_calibration` object.
7-
#' @param add_histogram Logical; add histogram of predictions (default: TRUE).
8-
#' @param add_rug Logical; add rug plot (default: TRUE).
8+
#' @param add_histogram Logical; add histogram of predictions below the
9+
#' calibration curve (default: TRUE).
10+
#' @param add_rug Logical; add rug plot to show individual predictions
11+
#' (default: FALSE).
912
#' @param ... Additional arguments passed to plotting functions.
1013
#'
1114
#' @return A ggplot object (if ggplot2 available) or base R plot.
1215
#'
1316
#' @export
14-
plot.cf_calibration <- function(x, add_histogram = TRUE, add_rug = TRUE, ...) {
17+
plot.cf_calibration <- function(x, add_histogram = TRUE, add_rug = FALSE, ...) {
1518

1619
estimator_label <- toupper(x$estimator)
1720
treatment_label <- x$treatment_level
1821

1922
if (!requireNamespace("ggplot2", quietly = TRUE)) {
20-
# Base R plot
23+
# Base R plot - simple version without histogram
24+
if (add_histogram) {
25+
old_par <- par(no.readonly = TRUE)
26+
on.exit(par(old_par))
27+
layout(matrix(c(1, 2), nrow = 2), heights = c(3, 1))
28+
par(mar = c(0, 4, 3, 2))
29+
}
30+
2131
plot(x$predicted, x$observed,
22-
type = "l", lwd = 2,
23-
xlab = "Predicted probability",
32+
type = "l", lwd = 2, col = "#2E86AB",
33+
xlab = if (add_histogram) "" else "Predicted probability",
2434
ylab = sprintf("Probability under A = %s (%s)", treatment_label, estimator_label),
2535
main = "Counterfactual Calibration Curve",
26-
xlim = c(0, 1), ylim = c(0, 1))
36+
xlim = c(0, 1), ylim = c(0, 1),
37+
xaxt = if (add_histogram) "n" else "s")
2738
abline(0, 1, lty = 2, col = "gray")
39+
40+
if (add_histogram && !is.null(x$predictions_raw)) {
41+
par(mar = c(4, 4, 0, 2))
42+
hist(x$predictions_raw, breaks = 30, col = "#2E86AB", border = "white",
43+
main = "", xlab = "Predicted probability", xlim = c(0, 1))
44+
}
45+
2846
return(invisible(NULL))
2947
}
3048

3149
# ggplot2 version
32-
df <- data.frame(
50+
df_curve <- data.frame(
3351
predicted = x$predicted,
3452
observed = x$observed
3553
)
3654

37-
p <- ggplot2::ggplot(df, ggplot2::aes(x = .data$predicted, y = .data$observed)) +
55+
# Main calibration curve
56+
p_cal <- ggplot2::ggplot(df_curve, ggplot2::aes(x = .data$predicted, y = .data$observed)) +
3857
ggplot2::geom_line(linewidth = 1.2, color = "#2E86AB") +
3958
ggplot2::geom_abline(slope = 1, intercept = 0, linetype = "dashed",
4059
color = "gray50") +
4160
ggplot2::labs(
42-
x = "Predicted probability",
4361
y = sprintf("Probability under A = %s (%s)", treatment_label, estimator_label),
4462
title = "Counterfactual Calibration Curve",
4563
subtitle = sprintf("ICI = %.3f, Emax = %.3f", x$ici, x$emax)
4664
) +
4765
ggplot2::coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
4866
ggplot2::theme_bw()
49-
50-
return(p)
67+
68+
# Add rug if requested
69+
if (add_rug && !is.null(x$predictions_raw)) {
70+
df_raw <- data.frame(predictions_raw = x$predictions_raw)
71+
p_cal <- p_cal +
72+
ggplot2::geom_rug(data = df_raw,
73+
ggplot2::aes(x = .data$predictions_raw, y = NULL),
74+
alpha = 0.3, color = "#2E86AB")
75+
}
76+
77+
# Add histogram subplot if requested
78+
if (add_histogram && !is.null(x$predictions_raw)) {
79+
if (!requireNamespace("patchwork", quietly = TRUE)) {
80+
# Fall back to just the calibration plot with rug
81+
message("Install 'patchwork' package for histogram subplot. Showing rug plot instead.")
82+
df_raw <- data.frame(predictions_raw = x$predictions_raw)
83+
p_cal <- p_cal +
84+
ggplot2::geom_rug(data = df_raw,
85+
ggplot2::aes(x = .data$predictions_raw, y = NULL),
86+
alpha = 0.3, color = "#2E86AB") +
87+
ggplot2::labs(x = "Predicted probability")
88+
return(p_cal)
89+
}
90+
91+
# Remove x-axis label from calibration plot
92+
p_cal <- p_cal +
93+
ggplot2::theme(axis.title.x = ggplot2::element_blank(),
94+
axis.text.x = ggplot2::element_blank(),
95+
axis.ticks.x = ggplot2::element_blank())
96+
97+
# Create histogram
98+
df_raw <- data.frame(predictions_raw = x$predictions_raw)
99+
p_hist <- ggplot2::ggplot(df_raw, ggplot2::aes(x = .data$predictions_raw)) +
100+
ggplot2::geom_histogram(bins = 30, fill = "#2E86AB", color = "white", alpha = 0.8) +
101+
ggplot2::labs(x = "Predicted probability", y = "Count") +
102+
ggplot2::coord_cartesian(xlim = c(0, 1)) +
103+
ggplot2::theme_bw() +
104+
ggplot2::theme(plot.margin = ggplot2::margin(t = 0, r = 5.5, b = 5.5, l = 5.5))
105+
106+
# Combine plots
107+
p_combined <- patchwork::wrap_plots(p_cal, p_hist, ncol = 1, heights = c(3, 1))
108+
return(p_combined)
109+
}
110+
111+
# No histogram - add x-axis label
112+
p_cal <- p_cal + ggplot2::labs(x = "Predicted probability")
113+
return(p_cal)
51114
}
52115

53116

R/tr_calibration.R

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ tr_calibration <- function(predictions,
361361
list(
362362
predicted = calib_result$predicted,
363363
observed = calib_result$observed,
364+
predictions_raw = predictions, # Raw predictions for histogram
364365
weights = calib_result$weights,
365366
smoother = smoother,
366367
ici = calib_result$ici,
@@ -782,26 +783,44 @@ tr_calibration <- function(predictions,
782783
#' @param x A `tr_calibration` object.
783784
#' @param add_reference Logical; add 45-degree reference line (default: TRUE).
784785
#' @param show_metrics Logical; show calibration metrics on plot (default: TRUE).
786+
#' @param add_histogram Logical; add histogram of predictions below the
787+
#' calibration curve (default: TRUE).
785788
#' @param ... Additional arguments passed to plotting functions.
786789
#'
787790
#' @return A ggplot object (if ggplot2 available) or base R plot.
788791
#'
789792
#' @export
790-
plot.tr_calibration <- function(x, add_reference = TRUE, show_metrics = TRUE, ...) {
793+
plot.tr_calibration <- function(x, add_reference = TRUE, show_metrics = TRUE,
794+
add_histogram = TRUE, ...) {
791795

792796
estimator_label <- toupper(x$estimator)
793797

794798
if (!requireNamespace("ggplot2", quietly = TRUE)) {
795-
# Base R plot
799+
# Base R plot - simple version without histogram
800+
if (add_histogram && !is.null(x$predictions_raw)) {
801+
old_par <- par(no.readonly = TRUE)
802+
on.exit(par(old_par))
803+
layout(matrix(c(1, 2), nrow = 2), heights = c(3, 1))
804+
par(mar = c(0, 4, 3, 2))
805+
}
806+
796807
plot(x$predicted, x$observed,
797-
type = "l", lwd = 2,
798-
xlab = "Predicted probability",
808+
type = "l", lwd = 2, col = "#2E86AB",
809+
xlab = if (add_histogram) "" else "Predicted probability",
799810
ylab = sprintf("Probability in Target Population (%s)", estimator_label),
800811
main = "Calibration Curve in the Target Population",
801-
xlim = c(0, 1), ylim = c(0, 1))
812+
xlim = c(0, 1), ylim = c(0, 1),
813+
xaxt = if (add_histogram) "n" else "s")
802814
if (add_reference) {
803815
abline(0, 1, lty = 2, col = "gray")
804816
}
817+
818+
if (add_histogram && !is.null(x$predictions_raw)) {
819+
par(mar = c(4, 4, 0, 2))
820+
hist(x$predictions_raw, breaks = 30, col = "#2E86AB", border = "white",
821+
main = "", xlab = "Predicted probability", xlim = c(0, 1))
822+
}
823+
805824
return(invisible(NULL))
806825
}
807826

@@ -817,10 +836,9 @@ plot.tr_calibration <- function(x, add_reference = TRUE, show_metrics = TRUE, ..
817836
x$ici, x$e50, x$emax)
818837
}
819838

820-
p <- ggplot2::ggplot(df, ggplot2::aes(x = .data$predicted, y = .data$observed)) +
839+
p_cal <- ggplot2::ggplot(df, ggplot2::aes(x = .data$predicted, y = .data$observed)) +
821840
ggplot2::geom_line(linewidth = 1.2, color = "#2E86AB") +
822841
ggplot2::labs(
823-
x = "Predicted probability",
824842
y = sprintf("Probability in Target Population (%s)", estimator_label),
825843
title = "Calibration Curve in the Target Population",
826844
subtitle = subtitle_text
@@ -829,9 +847,40 @@ plot.tr_calibration <- function(x, add_reference = TRUE, show_metrics = TRUE, ..
829847
ggplot2::theme_bw()
830848

831849
if (add_reference) {
832-
p <- p + ggplot2::geom_abline(slope = 1, intercept = 0,
833-
linetype = "dashed", color = "gray50")
850+
p_cal <- p_cal + ggplot2::geom_abline(slope = 1, intercept = 0,
851+
linetype = "dashed", color = "gray50")
834852
}
835853

836-
return(p)
854+
# Add histogram subplot if requested
855+
if (add_histogram && !is.null(x$predictions_raw)) {
856+
if (!requireNamespace("patchwork", quietly = TRUE)) {
857+
# Fall back to just the calibration plot
858+
message("Install 'patchwork' package for histogram subplot.")
859+
p_cal <- p_cal + ggplot2::labs(x = "Predicted probability")
860+
return(p_cal)
861+
}
862+
863+
# Remove x-axis label from calibration plot
864+
p_cal <- p_cal +
865+
ggplot2::theme(axis.title.x = ggplot2::element_blank(),
866+
axis.text.x = ggplot2::element_blank(),
867+
axis.ticks.x = ggplot2::element_blank())
868+
869+
# Create histogram
870+
df_raw <- data.frame(predictions_raw = x$predictions_raw)
871+
p_hist <- ggplot2::ggplot(df_raw, ggplot2::aes(x = .data$predictions_raw)) +
872+
ggplot2::geom_histogram(bins = 30, fill = "#2E86AB", color = "white", alpha = 0.8) +
873+
ggplot2::labs(x = "Predicted probability", y = "Count") +
874+
ggplot2::coord_cartesian(xlim = c(0, 1)) +
875+
ggplot2::theme_bw() +
876+
ggplot2::theme(plot.margin = ggplot2::margin(t = 0, r = 5.5, b = 5.5, l = 5.5))
877+
878+
# Combine plots
879+
p_combined <- patchwork::wrap_plots(p_cal, p_hist, ncol = 1, heights = c(3, 1))
880+
return(p_combined)
881+
}
882+
883+
# No histogram - add x-axis label
884+
p_cal <- p_cal + ggplot2::labs(x = "Predicted probability")
885+
return(p_cal)
837886
}

vignettes/introduction.html

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

vignettes/transportability.html

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)