From efa9ba17896f04ae80daf1bed001ae408195e33e Mon Sep 17 00:00:00 2001 From: Josh Buckner Date: Wed, 4 Mar 2026 11:06:54 -0500 Subject: [PATCH] Use meaningful group names in dbind() When a named list of ISMs is passed to dbind(), use the list names as group levels in the resulting BISM. When BISMs are in the mix, preserve their existing @groups levels. Unnamed entries fall back to position-based numeric indices, skipping any that collide with existing labels. Co-Authored-By: Claude Opus 4.6 --- R/InfinitySparseMatrix.R | 53 ++++++++++++++++------ tests/testthat/test.InfinitySparseMatrix.R | 40 ++++++++++++++++ 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/R/InfinitySparseMatrix.R b/R/InfinitySparseMatrix.R index fa7e498b..eb8bb08a 100644 --- a/R/InfinitySparseMatrix.R +++ b/R/InfinitySparseMatrix.R @@ -973,34 +973,61 @@ dbind <- function(..., force_unique_names = FALSE) { # Convert all matrices to ISMs if they aren't already. - mats <- lapply(mats, function(x) { + # Also build a parallel vector of group labels for each entry. + input_names <- names(mats) + converted <- lapply(seq_along(mats), function(i) { + x <- mats[[i]] + nm <- if (!is.null(input_names) && nzchar(input_names[i])) input_names[i] else NA_character_ + if (is(x, "BlockedInfinitySparseMatrix")) { - # Replace BISM with list of ISMs - findSubproblems(x) + # Replace BISM with list of ISMs; use its existing group level names + sp <- findSubproblems(x) + list(mats = sp, labels = names(sp)) } else if (inherits(x, "list")) { # If any entry in ... 
is a list, + inner_names <- names(x) # 1) Convert all entries in that list to ISM while keeping BISM as BISM x <- lapply(x, .as.ism_or_bism) - # 2) If we have any BISMs, split into list of ISMS - x <- lapply(x, function(y) { + # 2) If we have any BISMs, split into list of ISMs, preserving labels + inner_converted <- lapply(seq_along(x), function(j) { + y <- x[[j]] + inner_nm <- if (!is.null(inner_names) && nzchar(inner_names[j])) inner_names[j] else NA_character_ if (is(y, "BlockedInfinitySparseMatrix")) { - findSubproblems(y) + sp <- findSubproblems(y) + list(mats = sp, labels = names(sp)) } else { - y + list(mats = y, labels = inner_nm) } }) # 3) pull list of lists into list - flatten_list(x) + list(mats = flatten_list(lapply(inner_converted, `[[`, "mats")), + labels = unlist(lapply(inner_converted, `[[`, "labels"))) } else { # This will error appropriately if some element in `mats` cannot be # converted to an ISM. - .as.ism_or_bism(x) + list(mats = .as.ism_or_bism(x), labels = nm) } }) # If we were passed any BISMs, we have a list of lists of ISM, so flatten to a # single list. - mats <- flatten_list(mats) + mats <- flatten_list(lapply(converted, `[[`, "mats")) + group_labels <- unlist(lapply(converted, `[[`, "labels")) + + # Replace NA labels (from unnamed entries) with numeric indices based on + # their position, incrementing to avoid collisions with existing labels. + na_idx <- which(is.na(group_labels)) + if (length(na_idx) > 0) { + existing <- group_labels[!is.na(group_labels)] + for (i in na_idx) { + candidate <- i + while (as.character(candidate) %in% existing) { + candidate <- candidate + 1L + } + group_labels[i] <- as.character(candidate) + existing <- c(existing, as.character(candidate)) + } + } # new row and column positions are based on current, incrementing by number of # rows/columns in all previous matrices. 
@@ -1052,10 +1079,10 @@ dbind <- function(..., force_unique_names = FALSE) { newdim <- as.integer(c(sum(vapply(lapply(mats, methods::slot, "dimension"), "[", 1, 1)), sum(vapply(lapply(mats, methods::slot, "dimension"), "[", 1, 2)))) - # This needs to be much smarter, especially if any element is already a BISM - groups <- as.factor(rep(seq_along(mats), times = + groups <- factor(rep(group_labels, times = vapply(lapply(mats, slot, "colnames"), length, 1) + - vapply(lapply(mats, slot, "rownames"), length, 1))) + vapply(lapply(mats, slot, "rownames"), length, 1)), + levels = unique(group_labels)) names(groups) <- do.call(c, Map(c, cnameslist, rnameslist)) newdata <- do.call(c, mats) diff --git a/tests/testthat/test.InfinitySparseMatrix.R b/tests/testthat/test.InfinitySparseMatrix.R index 152abf66..e023340b 100644 --- a/tests/testthat/test.InfinitySparseMatrix.R +++ b/tests/testthat/test.InfinitySparseMatrix.R @@ -955,6 +955,46 @@ test_that("dbind", { expect_identical(bmix1, bmix3) }) +test_that("dbind uses meaningful group names", { + data(nuclearplants) + np <- nuclearplants + np$group <- as.numeric(cut(np$cap, breaks = c(0, 600, 825, 1000, 2000))) + + m1 <- match_on(pr ~ cost, data = np[np$group == 1,]) + m2 <- match_on(pr ~ cost, data = np[np$group == 2,]) + m3 <- match_on(pr ~ cost, data = np[np$group == 3,]) + m4 <- match_on(pr ~ cost, data = np[np$group == 4,]) + + # Named list of ISMs -> group levels match the list names + bm_named <- dbind(first = m1, second = m2) + expect_identical(levels(bm_named@groups), c("first", "second")) + + # Unnamed list -> levels are "1", "2", ... 
(backward compat) + bm_unnamed <- dbind(m1, m2) + expect_identical(levels(bm_unnamed@groups), c("1", "2")) + + # Unnamed BISM mixed with a named ISM argument -> ISM entries get their argument names, + # BISM entries get their original @groups levels + b1 <- match_on(pr ~ cost + strata(group), data = np[np$group < 3,]) + bm_mixed <- dbind(b1, extra = m3) + bism_levels <- levels(b1@groups) + expect_identical(levels(bm_mixed@groups), c(bism_levels, "extra")) + + # Partially named arguments -> named entries use their names, unnamed ones fall back to their position index + bm_partial <- dbind(a = m1, m2, m3) + expect_identical(levels(bm_partial@groups), c("a", "2", "3")) + + # Unnamed fallback indices skip names that collide with existing labels + bm_collision <- dbind(m1, b1) + # m1's position index "1" collides with b1's groups "1"/"2"; incrementing past both gives "3" + # Levels preserve input order: m1's group first, then b1's groups + expect_identical(levels(bm_collision@groups), c("3", bism_levels)) + + # Named list passed as a single argument + bm_named_list <- dbind(list(x = m1, y = m2)) + expect_identical(levels(bm_named_list@groups), c("x", "y")) +}) + test_that("dbind'ing a very large number of matrices", { data(nuclearplants)