Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions R/InfinitySparseMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -973,34 +973,61 @@ dbind <- function(..., force_unique_names = FALSE) {


# Convert all matrices to ISMs if they aren't already.
mats <- lapply(mats, function(x) {
# Also build a parallel vector of group labels for each entry.
input_names <- names(mats)
converted <- lapply(seq_along(mats), function(i) {
x <- mats[[i]]
nm <- if (!is.null(input_names) && nzchar(input_names[i])) input_names[i] else NA_character_
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slick. (This use of nzchar() was new to me, though it is clearly the way to go. Did Claude help you get to this, @nullsatz?)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, Claude did that.


if (is(x, "BlockedInfinitySparseMatrix")) {
# Replace BISM with list of ISMs
findSubproblems(x)
# Replace BISM with list of ISMs; use its existing group level names
sp <- findSubproblems(x)
list(mats = sp, labels = names(sp))
} else if (inherits(x, "list")) {
# If any entry in ... is a list,
inner_names <- names(x)
# 1) Convert all entries in that list to ISM while keeping BISM as BISM
x <- lapply(x, .as.ism_or_bism)
# 2) If we have any BISMs, split into list of ISMS
x <- lapply(x, function(y) {
# 2) If we have any BISMs, split into list of ISMs, preserving labels
inner_converted <- lapply(seq_along(x), function(j) {
y <- x[[j]]
inner_nm <- if (!is.null(inner_names) && nzchar(inner_names[j])) inner_names[j] else NA_character_
if (is(y, "BlockedInfinitySparseMatrix")) {
findSubproblems(y)
sp <- findSubproblems(y)
list(mats = sp, labels = names(sp))
} else {
y
list(mats = y, labels = inner_nm)
}
})
# 3) pull list of lists into list
flatten_list(x)
list(mats = flatten_list(lapply(inner_converted, `[[`, "mats")),
labels = unlist(lapply(inner_converted, `[[`, "labels")))
} else {
# This will error appropriately if some element in `mats` cannot be
# converted to an ISM.
.as.ism_or_bism(x)
list(mats = .as.ism_or_bism(x), labels = nm)
}
})

# If we were passed any BISMs, we have a list of lists of ISM, so flatten to a
# single list.
mats <- flatten_list(mats)
mats <- flatten_list(lapply(converted, `[[`, "mats"))
group_labels <- unlist(lapply(converted, `[[`, "labels"))

# Replace NA labels (from unnamed entries) with numeric indices based on
# their position, incrementing to avoid collisions with existing labels.
na_idx <- which(is.na(group_labels))
if (length(na_idx) > 0) {
existing <- group_labels[!is.na(group_labels)]
for (i in na_idx) {
candidate <- i
while (as.character(candidate) %in% existing) {
candidate <- candidate + 1L
}
group_labels[i] <- as.character(candidate)
existing <- c(existing, as.character(candidate))
}
}

# new row and column positions are based on current, incrementing by number of
# rows/columns in all previous matrices.
Expand Down Expand Up @@ -1052,10 +1079,10 @@ dbind <- function(..., force_unique_names = FALSE) {
newdim <- as.integer(c(sum(vapply(lapply(mats, methods::slot, "dimension"), "[", 1, 1)),
sum(vapply(lapply(mats, methods::slot, "dimension"), "[", 1, 2))))

# This needs to be much smarter, especially if any element is already a BISM
groups <- as.factor(rep(seq_along(mats), times =
groups <- factor(rep(group_labels, times =
vapply(lapply(mats, slot, "colnames"), length, 1) +
vapply(lapply(mats, slot, "rownames"), length, 1)))
vapply(lapply(mats, slot, "rownames"), length, 1)),
levels = unique(group_labels))
names(groups) <- do.call(c, Map(c, cnameslist, rnameslist))

newdata <- do.call(c, mats)
Expand Down
40 changes: 40 additions & 0 deletions tests/testthat/test.InfinitySparseMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,46 @@ test_that("dbind", {
expect_identical(bmix1, bmix3)
})

test_that("dbind uses meaningful group names", {
  # Checks how dbind() labels the levels of the `groups` slot of its result:
  # argument/list names are used where given, a BISM contributes its own
  # group levels, and unnamed entries fall back to positional indices.
  data(nuclearplants)
  np <- nuclearplants
  # Bin rows into four groups on the `cap` column so that each sub-matrix
  # below is built from a disjoint subset of the data.
  np$group <- as.numeric(cut(np$cap, breaks = c(0, 600, 825, 1000, 2000)))

  m1 <- match_on(pr ~ cost, data = np[np$group == 1, ])
  m2 <- match_on(pr ~ cost, data = np[np$group == 2, ])
  m3 <- match_on(pr ~ cost, data = np[np$group == 3, ])

  # Named arguments -> group levels match the argument names
  bm_named <- dbind(first = m1, second = m2)
  expect_identical(levels(bm_named@groups), c("first", "second"))

  # Unnamed arguments -> levels are "1", "2", ... (backward compat)
  bm_unnamed <- dbind(m1, m2)
  expect_identical(levels(bm_unnamed@groups), c("1", "2"))

  # Mix of an unnamed BISM and a named ISM -> the ISM gets its argument name,
  # the BISM keeps its original @groups levels
  b1 <- match_on(pr ~ cost + strata(group), data = np[np$group < 3, ])
  bm_mixed <- dbind(b1, extra = m3)
  bism_levels <- levels(b1@groups)
  expect_identical(levels(bm_mixed@groups), c(bism_levels, "extra"))

  # Partially named arguments -> named entries use their names, unnamed
  # entries fall back to their position
  bm_partial <- dbind(a = m1, m2, m3)
  expect_identical(levels(bm_partial@groups), c("a", "2", "3"))

  # Unnamed fallback indices skip values that collide with existing labels
  bm_collision <- dbind(m1, b1)
  # m1 at position 1 collides with b1's groups "1"/"2", so gets "3"
  # Levels preserve input order: m1's group first, then b1's groups
  expect_identical(levels(bm_collision@groups), c("3", bism_levels))

  # Named list passed as a single argument -> levels match the list names
  bm_named_list <- dbind(list(x = m1, y = m2))
  expect_identical(levels(bm_named_list@groups), c("x", "y"))
})

test_that("dbind'ing a very large number of matrices", {
data(nuclearplants)

Expand Down