diff --git a/DESCRIPTION b/DESCRIPTION index 8db9ab4..eeb67f7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -83,6 +83,7 @@ Collate: 'module-program-of-thought.R' 'module-rag.R' 'module-react.R' + 'module-rlm.R' 'module-wrapper.R' 'module.R' 'optimize.R' @@ -96,6 +97,7 @@ Collate: 'pipeline.R' 'r-code-runner.R' 'ragnar.R' + 'rlm-tools.R' 'run.R' 'signature-parser.R' 'signature-transforms.R' diff --git a/NAMESPACE b/NAMESPACE index 7f4f85c..f5061a2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -182,6 +182,7 @@ export(refine) export(register_dsprrr_engine) export(register_dsprrr_tool) export(restore_module_config) +export(rlm_module) export(run) export(run_async) export(run_dataset) diff --git a/R/module-rlm.R b/R/module-rlm.R new file mode 100644 index 0000000..ae8213b --- /dev/null +++ b/R/module-rlm.R @@ -0,0 +1,1033 @@ +#' Recursive Language Model (RLM) Module +#' +#' @description +#' A module that transforms context from "input" to "environment", enabling LLMs +#' to programmatically explore large contexts through a REPL interface rather +#' than embedding them in prompts. +#' +#' @details +#' Instead of `llm(prompt, context=huge_document)`, RLM stores context as R +#' variables that the LLM can peek, slice, search, and recursively query. +#' +#' The execution flow is: +#' 1. Context is made available as variables in an R execution environment +#' 2. LLM generates R code to explore and analyze the context +#' 3. Code is executed in an isolated subprocess via RCodeRunner +#' 4. Results are fed back to the LLM for the next iteration +#' 5. Process continues until SUBMIT() is called or max_iterations reached +#' 6. 
If max_iterations reached without SUBMIT(), fallback extraction is used +#' +#' Available REPL tools: +#' - `SUBMIT(answer)`: Terminate and return final answer +#' - `peek(var, start, end)`: View a slice of a variable (default: first 1000 chars) +#' - `search(var, pattern)`: Regex search in variable +#' - `rlm_query(query, context_slice)`: Recursive LLM call (requires sub_lm) +#' - `rlm_query_batch(queries, slices)`: Batched recursive calls +#' +#' Security: Code execution requires explicit opt-in via a runner parameter. +#' The runner provides subprocess isolation but is NOT a security sandbox. +#' For untrusted inputs, use OS-level sandboxing (containers, AppArmor). +#' +#' @examples +#' \dontrun{ +#' # Create a runner (required for code execution) +#' runner <- r_code_runner(timeout = 30) +#' +#' # Create an RLM module for exploring large documents +#' rlm <- rlm_module( +#' signature = "document, question -> answer", +#' runner = runner +#' ) +#' +#' # Use it for context exploration +#' long_doc <- paste(readLines("large_file.txt"), collapse = "\n") +#' result <- run(rlm, document = long_doc, question = "What are the main themes?", .llm = llm) +#' +#' # Enable recursive LLM calls for complex reasoning +#' rlm_recursive <- rlm_module( +#' signature = "document -> summary", +#' runner = runner, +#' sub_lm = ellmer::chat_openai(model = "gpt-4o-mini"), +#' max_llm_calls = 10 +#' ) +#' } +#' +#' @name module-rlm +NULL + + +#' Create a Recursive Language Model (RLM) Module +#' +#' @description +#' Factory function to create an RLMModule that enables LLMs to programmatically +#' explore large contexts through a REPL interface. +#' +#' @param signature A Signature object or string notation defining inputs/outputs +#' @param runner An RCodeRunner object for code execution. Required. 
+#' @param max_iterations Maximum REPL iterations before fallback (default 20) +#' @param max_llm_calls Maximum recursive LLM calls allowed (default 50) +#' @param max_output_chars Maximum characters per execution output (default 100000) +#' @param sub_lm Optional ellmer Chat for recursive queries. NULL = disabled. +#' @param verbose Logical. Print execution progress (default FALSE) +#' @param tools Named list of user-defined R functions to inject into REPL. +#' Each tool becomes available as a function in the code execution environment. +#' Non-function values in the list will cause an error. +#' @param ... Additional arguments passed to the module +#' +#' @return An RLMModule object +#' +#' @export +#' @examples +#' \dontrun{ +#' runner <- r_code_runner(timeout = 30) +#' rlm <- rlm_module("question -> answer", runner = runner) +#' result <- run(rlm, question = "What is the 10th Fibonacci number?", .llm = llm) +#' } +rlm_module <- function( + signature, + runner, + max_iterations = 20L, + max_llm_calls = 50L, + max_output_chars = 100000L, + sub_lm = NULL, + verbose = FALSE, + tools = list(), + ... 
+) { + # Validate runner + if (missing(runner) || is.null(runner)) { + cli::cli_abort(c( + "RLM requires an explicit runner for code execution", + "i" = "Create one with: {.code runner <- r_code_runner()}", + "i" = "Then pass it: {.code rlm_module(..., runner = runner)}" + )) + } + + if (!inherits(runner, "RCodeRunner")) { + cli::cli_abort(c( + "runner must be an RCodeRunner object", + "x" = "You provided: {.cls {class(runner)[1]}}", + "i" = "Create one with: {.code r_code_runner()}" + )) + } + + # Parse signature if string + if (is.character(signature)) { + signature <- signature(signature) + } + + if (!S7::S7_inherits(signature, Signature)) { + cli::cli_abort(c( + "signature must be a Signature object or string notation", + "x" = "You provided: {.cls {class(signature)[1]}}" + )) + } + + # Validate bounds for iterations and calls + max_iterations <- as.integer(max_iterations) + max_llm_calls <- as.integer(max_llm_calls) + + if (max_iterations < 1L) { + cli::cli_abort(c( + "max_iterations must be at least 1", + "x" = "You provided: {.val {max_iterations}}" + )) + } + + if (max_llm_calls < 0L) { + cli::cli_abort(c( + "max_llm_calls must be non-negative", + "x" = "You provided: {.val {max_llm_calls}}" + )) + } + + # Validate tools + if (!is.list(tools)) { + cli::cli_abort(c( + "tools must be a named list of functions", + "x" = "You provided: {.cls {class(tools)[1]}}" + )) + } + + if (length(tools) > 0 && is.null(names(tools))) { + cli::cli_abort(c( + "tools must be a named list", + "i" = "Example: {.code tools = list(my_func = function(...) 
...)}" + )) + } + + # Validate all tools are functions + if (length(tools) > 0) { + non_functions <- vapply(tools, Negate(is.function), logical(1)) + if (any(non_functions)) { + bad_names <- names(tools)[non_functions] + cli::cli_abort(c( + "All tools must be functions", + "x" = "Non-function tool{?s}: {.val {bad_names}}" + )) + } + } + + RLMModule$new( + signature = signature, + runner = runner, + max_iterations = max_iterations, + max_llm_calls = max_llm_calls, + max_output_chars = as.integer(max_output_chars), + sub_lm = sub_lm, + verbose = verbose, + tools = tools, + ... + ) +} + + +#' RLM Module R6 Class +#' +#' @description +#' R6 class implementing the Recursive Language Model pattern: LLM-driven +#' REPL exploration of context with programmatic tools. +#' +#' @keywords internal +#' @noRd +RLMModule <- R6::R6Class( + "RLMModule", + inherit = Module, + public = list( + #' @field runner RCodeRunner for code execution + runner = NULL, + + #' @field max_iterations Maximum REPL iterations before fallback + max_iterations = NULL, + + #' @field max_llm_calls Maximum recursive LLM calls + max_llm_calls = NULL, + + #' @field max_output_chars Maximum output size per execution + max_output_chars = NULL, + + #' @field sub_lm Optional LLM for recursive queries + sub_lm = NULL, + + #' @field verbose Whether to print execution progress + verbose = NULL, + + #' @field tools User-defined REPL tools + tools = NULL, + + #' @description + #' Initialize an RLMModule + #' + #' @param signature Signature object defining inputs/outputs + #' @param runner RCodeRunner for code execution + #' @param max_iterations Maximum REPL iterations + #' @param max_llm_calls Maximum recursive LLM calls + #' @param max_output_chars Maximum output size + #' @param sub_lm Optional LLM for recursive queries + #' @param verbose Whether to print progress + #' @param tools User-defined tools + #' @param config Optional configuration list + #' @param chat Optional ellmer Chat object + initialize = 
function( + signature, + runner, + max_iterations = 20L, + max_llm_calls = 50L, + max_output_chars = 100000L, + sub_lm = NULL, + verbose = FALSE, + tools = list(), + config = list(), + chat = NULL + ) { + super$initialize( + signature = signature, + config = config, + chat = chat + ) + + self$runner <- runner + self$max_iterations <- as.integer(max_iterations) + self$max_llm_calls <- as.integer(max_llm_calls) + self$max_output_chars <- as.integer(max_output_chars) + self$sub_lm <- sub_lm + self$verbose <- verbose + self$tools <- tools + + # Store REPL history per execution + self$state$repl_history <- list() + }, + + #' @description + #' Execute the RLM workflow + #' + #' @param batch Named list or data frame of inputs + #' @param .llm Optional ellmer chat object + #' @param trace Logical whether to record trace information + #' @param ... Additional arguments + #' @return Tibble with output, chat, metadata columns + forward = function(batch, .llm = NULL, trace = TRUE, ...) { + # Handle inputs + if (is.data.frame(batch)) { + inputs <- as.list(batch[1, , drop = FALSE]) + } else { + inputs <- batch + } + + # Get LLM - clone for fresh conversation + base_llm <- .llm %||% self$chat %||% get_default_chat() + if (is.null(base_llm)) { + cli::cli_abort("No LLM provided. 
Pass .llm or set a default chat.") + } + + llm <- base_llm$clone() + + start_time <- Sys.time() + + # Initialize call counter (shared across recursions) + call_counter <- new.env() + call_counter$count <- 0L + + # Build context description for system prompt + context_desc <- private$describe_context(inputs) + + # Build system prompt + system_prompt <- private$build_system_prompt(context_desc) + + # REPL iteration loop + history <- list() + final_answer <- NULL + + for (iter in seq_len(self$max_iterations)) { + if (self$verbose) { + cli::cli_alert_info("RLM Iteration {iter}/{self$max_iterations}") + } + + # Build prompt for this iteration + prompt <- private$build_iteration_prompt(system_prompt, history, iter) + + # Get LLM response (code generation) + response <- private$get_code_response(llm, prompt) + + if (self$verbose) { + cli::cli_alert("Code generated: {substr(response$code, 1, 100)}...") + } + + # Execute code with RLM tools injected + exec_result <- private$execute_with_rlm_tools( + response$code, + inputs, + call_counter + ) + + # Record in history + history[[iter]] <- list( + iteration = iter, + reasoning = response$reasoning, + code = response$code, + output = exec_result$formatted_output, + success = exec_result$success, + is_final = exec_result$is_final + ) + + # Check for SUBMIT() termination + if (exec_result$is_final) { + final_answer <- exec_result$final_value + if (self$verbose) { + cli::cli_alert_success("SUBMIT called with answer") + } + break + } + + # Check for errors - feed back to LLM for retry + if (!exec_result$success) { + if (self$verbose) { + cli::cli_alert_warning( + "Iteration {iter}: Code execution failed - {exec_result$error}" + ) + } + # Error will be in history, LLM can see and fix + next + } + + if (self$verbose) { + cli::cli_alert( + "Output: {substr(exec_result$formatted_output, 1, 200)}..." 
+ ) + } + } + + # Fallback extract if no SUBMIT() + if (is.null(final_answer)) { + cli::cli_warn(c( + "RLM reached max_iterations ({self$max_iterations}) without SUBMIT()", + "i" = "Using fallback extraction from trajectory", + "i" = "Consider increasing max_iterations or simplifying the query" + )) + final_answer <- private$extract_fallback(inputs, history, llm) + } + + # Store REPL history + if (trace) { + self$state$repl_history <- c( + self$state$repl_history, + list(list( + timestamp = start_time, + inputs = inputs, + history = history, + final_answer = final_answer, + iterations_used = length(history), + llm_calls_used = call_counter$count + )) + ) + } + + # Build output matching signature + output <- private$build_output(final_answer) + + duration_secs <- as.numeric(difftime( + Sys.time(), + start_time, + units = "secs" + )) + duration_ms <- duration_secs * 1000 + + # Build metadata + metadata <- list( + model = "rlm", + iterations = length(history), + max_iterations = self$max_iterations, + llm_calls = call_counter$count, + max_llm_calls = self$max_llm_calls, + duration_ms = round(duration_ms, 2), + repl_history = history + ) + + tibble::tibble( + output = list(output), + chat = list(llm), + metadata = list(metadata) + ) + }, + + #' @description + #' Get REPL history for inspection + #' @return List of REPL execution records + get_repl_history = function() { + self$state$repl_history + }, + + #' @description + #' Create a fresh copy of this module + #' @return New RLMModule with same settings + reset_copy = function() { + RLMModule$new( + signature = self$signature, + runner = self$runner, + max_iterations = self$max_iterations, + max_llm_calls = self$max_llm_calls, + max_output_chars = self$max_output_chars, + sub_lm = self$sub_lm, + verbose = self$verbose, + tools = self$tools, + config = self$config, + chat = self$chat + ) + }, + + #' @description + #' Print method for RLMModule + print = function() { + # Format signature + input_names <- vapply( + 
self$signature@inputs, + function(x) x$name, + character(1) + ) + sig_str <- paste0( + paste(input_names, collapse = ", "), + " -> ", + private$get_output_names() + ) + + cli::cli_h3("RLMModule") + cli::cli_bullets(c( + "*" = "Signature: {sig_str}", + "*" = "Max iterations: {.val {self$max_iterations}}", + "*" = "Max LLM calls: {.val {self$max_llm_calls}}", + "*" = "Runner timeout: {.val {self$runner$timeout}}s", + "*" = "Recursive queries: {.val {if (is.null(self$sub_lm)) 'disabled' else 'enabled'}}", + "*" = "Custom tools: {.val {length(self$tools)}}" + )) + invisible(self) + } + ), + + private = list( + #' Get output field names from signature + get_output_names = function() { + output_type <- self$signature@output_type + if (methods::.hasSlot(output_type, "properties")) { + props <- output_type@properties + if (length(props) > 0) { + return(paste(names(props), collapse = ", ")) + } + } + "answer" + }, + + #' Describe context variables for the system prompt + describe_context = function(inputs) { + if (length(inputs) == 0) { + return("No context variables available.") + } + + descriptions <- vapply( + names(inputs), + function(name) { + val <- inputs[[name]] + type_str <- class(val)[1] + size_str <- if (is.character(val)) { + total_chars <- sum(nchar(val)) + paste0(total_chars, " characters") + } else if (is.data.frame(val)) { + paste0(nrow(val), " rows x ", ncol(val), " cols") + } else if (is.list(val)) { + paste0(length(val), " elements") + } else if (is.vector(val)) { + paste0(length(val), " elements") + } else { + "1 object" + } + + preview <- if (is.character(val) && length(val) == 1) { + if (nchar(val) > 100) { + paste0(substr(val, 1, 100), "...") + } else { + val + } + } else { + deparse(val, width.cutoff = 60)[1] + } + + paste0( + "- `.context$", + name, + "` (", + type_str, + ", ", + size_str, + ")\n", + " Preview: ", + preview + ) + }, + character(1) + ) + + paste(descriptions, collapse = "\n\n") + }, + + #' Build the system prompt for RLM + 
build_system_prompt = function(context_desc) { + has_sub_lm <- !is.null(self$sub_lm) + + # Build tool descriptions + tool_desc <- c( + "- `SUBMIT(answer)`: Submit your final answer and terminate", + "- `peek(var, start = 1, end = 1000)`: View a character slice of a variable", + "- `search(var, pattern)`: Regex search in variable, returns matches" + ) + + if (has_sub_lm) { + tool_desc <- c( + tool_desc, + "- `rlm_query(query, context_slice = NULL)`: Ask a sub-question to another LLM", + "- `rlm_query_batch(queries, slices = NULL)`: Batch multiple sub-questions" + ) + } + + if (length(self$tools) > 0) { + custom_tool_names <- names(self$tools) + tool_desc <- c( + tool_desc, + paste0("- `", custom_tool_names, "()`: User-defined tool") + ) + } + + glue::glue( + " +You are working in an R REPL environment. Your goal is to answer the query by +writing R code to explore and analyze the provided context. + +## Available Variables +{context_desc} + +## Available Functions +{paste(tool_desc, collapse = ' +')} + +## Rules +1. Explore the context programmatically - don't ask to see all content at once +2. Use peek() to examine slices of large text +3. Use search() to find specific patterns +4. Break complex queries into smaller sub-questions{if (has_sub_lm) ' using rlm_query()' else ''} +5. Call SUBMIT(answer) when you have the final answer +6. You have {self$max_iterations} iterations{if (has_sub_lm) paste0(' and ', self$max_llm_calls, ' LLM calls') else ''} + +## Response Format +Return JSON with: +- \"reasoning\": Your thought process for this step +- \"code\": R code to execute (single string) + +The code's output will be shown to you. Continue until you call SUBMIT(). 
+" + ) + }, + + #' Build prompt for a specific iteration + build_iteration_prompt = function(system_prompt, history, iter) { + if (iter == 1) { + # First iteration - just system prompt + return(system_prompt) + } + + # Build history context + history_parts <- vapply( + history, + function(h) { + glue::glue( + " +## Iteration {h$iteration} +Reasoning: {h$reasoning} + +Code: +```r +{h$code} +``` + +{if (h$success) 'Output:' else 'Error:'} +{h$output} +{if (h$is_final) '(SUBMIT was called)' else ''} +" + ) + }, + character(1) + ) + + paste0( + system_prompt, + "\n\n## Previous Iterations\n", + paste(history_parts, collapse = "\n"), + "\n\n## Next Step\nContinue exploring or call SUBMIT() with your answer." + ) + }, + + #' Get code response from LLM + get_code_response = function(llm, prompt) { + output_type <- ellmer::type_object( + reasoning = ellmer::type_string( + description = "Your thought process for this step" + ), + code = ellmer::type_string( + description = "R code to execute" + ) + ) + + result <- tryCatch( + llm$chat_structured(prompt, type = output_type), + error = function(e) { + cli::cli_abort(c( + "Failed to get code from LLM", + "x" = "Error: {e$message}" + )) + } + ) + + if (is.null(result$code) || !is.character(result$code)) { + cli::cli_abort(c( + "LLM returned invalid response", + "i" = "Missing or invalid 'code' field" + )) + } + + list( + reasoning = result$reasoning %||% "", + code = result$code + ) + }, + + #' Execute code with RLM tools injected + execute_with_rlm_tools = function(code, inputs, call_counter) { + # Build RLM prelude that defines tools + rlm_prelude <- create_rlm_prelude( + max_llm_calls = self$max_llm_calls, + has_sub_lm = !is.null(self$sub_lm), + custom_tools = self$tools + ) + + # Build combined code: prelude + user code + combined_code <- paste0( + "# RLM Prelude\n", + rlm_prelude, + "\n\n# User Code\n", + code + ) + + # Execute with inputs as context + result <- self$runner$execute(combined_code, context = inputs) + + # 
Validate runner result structure + if (!is.list(result)) { + cli::cli_abort(c( + "Runner returned invalid result", + "x" = "Expected list, got {.cls {class(result)[1]}}" + )) + } + + required_fields <- c("success", "result") + missing_fields <- setdiff(required_fields, names(result)) + if (length(missing_fields) > 0) { + cli::cli_abort(c( + "Runner result missing required fields", + "x" = "Missing: {.val {missing_fields}}" + )) + } + + # Detect SUBMIT termination using helper function + is_final <- is_rlm_final(result$result) + final_value <- if (is_final) { + extract_rlm_final(result$result) + } else { + NULL + } + + # Handle rlm_query requests (if sub_lm is available) + if (is_rlm_query_request(result$result) && !is.null(self$sub_lm)) { + # Process the recursive query (single or batch) + query_result <- private$process_rlm_query( + result$result, + call_counter + ) + + # Return the query result as the output + return(list( + success = query_result$success, + is_final = FALSE, + final_value = NULL, + formatted_output = query_result$formatted_output, + error = query_result$error, + raw_result = result + )) + } + + # Format output for history + formatted_output <- if (result$success) { + private$format_execution_output(result) + } else { + paste("Error:", result$error %||% "Unknown error") + } + + # Truncate if too long + if (nchar(formatted_output) > self$max_output_chars) { + formatted_output <- paste0( + substr(formatted_output, 1, self$max_output_chars), + "\n... 
[TRUNCATED]" + ) + } + + list( + success = result$success, + is_final = is_final, + final_value = final_value, + formatted_output = formatted_output, + error = result$error, + raw_result = result + ) + }, + + #' Process an rlm_query request (single or batch) + #' + #' @return List with success, formatted_output, error fields + process_rlm_query = function(request, call_counter) { + # Handle batch queries + if (isTRUE(request$batch)) { + return(private$process_rlm_query_batch(request, call_counter)) + } + + # Single query processing + # Check call limit + if (call_counter$count >= self$max_llm_calls) { + error_msg <- paste0( + "Maximum LLM calls (", + self$max_llm_calls, + ") exceeded" + ) + cli::cli_warn(error_msg) + return(list( + success = FALSE, + formatted_output = paste0("Error: ", error_msg), + error = error_msg + )) + } + + call_counter$count <- call_counter$count + 1L + + query <- request$query + context_slice <- request$context + + prompt <- if (!is.null(context_slice)) { + paste0("Context:\n", context_slice, "\n\nQuestion: ", query) + } else { + query + } + + result <- tryCatch( + { + response <- self$sub_lm$chat(prompt) + list( + success = TRUE, + formatted_output = paste0("Query result: ", response), + error = NULL + ) + }, + error = function(e) { + cli::cli_warn(c( + "Recursive LLM query failed", + "x" = "Error: {e$message}", + "i" = "Query: {substr(query, 1, 100)}..." 
+ )) + list( + success = FALSE, + formatted_output = paste0("Query error: ", e$message), + error = e$message + ) + } + ) + + result + }, + + #' Process batch rlm_query requests + #' + #' @return List with success, formatted_output, error fields + process_rlm_query_batch = function(request, call_counter) { + queries <- request$queries + slices <- request$slices + n_queries <- length(queries) + + # Check if we have enough calls remaining + remaining_calls <- self$max_llm_calls - call_counter$count + if (n_queries > remaining_calls) { + error_msg <- paste0( + "Batch of ", + n_queries, + " queries would exceed limit. ", + "Remaining calls: ", + remaining_calls + ) + cli::cli_warn(error_msg) + return(list( + success = FALSE, + formatted_output = paste0("Error: ", error_msg), + error = error_msg + )) + } + + # Process each query + results <- vector("list", n_queries) + errors <- character() + + for (i in seq_len(n_queries)) { + call_counter$count <- call_counter$count + 1L + + query <- queries[[i]] + context_slice <- if (!is.null(slices)) slices[[i]] else NULL + + prompt <- if (!is.null(context_slice)) { + paste0("Context:\n", context_slice, "\n\nQuestion: ", query) + } else { + query + } + + results[[i]] <- tryCatch( + { + self$sub_lm$chat(prompt) + }, + error = function(e) { + errors <<- c(errors, paste0("Query ", i, ": ", e$message)) + paste0("[Error: ", e$message, "]") + } + ) + } + + # Format output + formatted_parts <- vapply( + seq_len(n_queries), + function(i) paste0("Query ", i, " result: ", results[[i]]), + character(1) + ) + formatted_output <- paste(formatted_parts, collapse = "\n\n") + + if (length(errors) > 0) { + cli::cli_warn(c( + "Some batch queries failed", + "x" = errors + )) + } + + list( + success = length(errors) == 0, + formatted_output = formatted_output, + error = if (length(errors) > 0) paste(errors, collapse = "; ") else NULL + ) + }, + + #' Format execution output for history + format_execution_output = function(result) { + parts <- 
character() + + # Safely check stdout (may be NULL or missing) + stdout_val <- result$stdout + if ( + !is.null(stdout_val) && + is.character(stdout_val) && + nchar(stdout_val) > 0 + ) { + parts <- c(parts, paste0("stdout:\n", stdout_val)) + } + + # Safely check messages (may be NULL or missing) + messages_val <- result$messages + if ( + !is.null(messages_val) && + is.character(messages_val) && + nchar(messages_val) > 0 + ) { + parts <- c(parts, paste0("messages:\n", messages_val)) + } + + # Safely check warnings (may be NULL or missing) + warnings_val <- result$warnings + if ( + !is.null(warnings_val) && + is.character(warnings_val) && + nchar(warnings_val) > 0 + ) { + parts <- c(parts, paste0("warnings:\n", warnings_val)) + } + + if (!is.null(result$result)) { + result_str <- tryCatch( + { + if (is.data.frame(result$result)) { + paste( + utils::capture.output(print(result$result)), + collapse = "\n" + ) + } else if ( + is.atomic(result$result) && length(result$result) <= 10 + ) { + paste(result$result, collapse = ", ") + } else { + paste(utils::capture.output(str(result$result)), collapse = "\n") + } + }, + error = function(e) deparse(result$result)[1] + ) + parts <- c(parts, paste0("result:\n", result_str)) + } + + if (length(parts) == 0) { + return("[No output]") + } + + paste(parts, collapse = "\n\n") + }, + + #' Extract answer via fallback when max_iterations reached + extract_fallback = function(inputs, history, llm) { + # Build trajectory summary + trajectory <- vapply( + history, + function(h) { + glue::glue( + "Iteration {h$iteration}: +Reasoning: {h$reasoning} +Code: {h$code} +Output: {substr(h$output, 1, 500)}" + ) + }, + character(1) + ) + + input_context <- private$format_inputs_for_prompt(inputs) + + prompt <- glue::glue( + " +The RLM agent ran out of iterations before calling SUBMIT(). +Based on the exploration trajectory below, extract the best possible answer. 
+ +## Original Query +{input_context} + +## Exploration Trajectory +{paste(trajectory, collapse = ' +--- +')} + +## Task +Based on the above exploration, provide the final answer to the original query. +Be concise and direct. If the exploration was incomplete, provide the best +answer possible with what was discovered. +" + ) + + tryCatch( + llm$chat(prompt), + error = function(e) { + cli::cli_warn(c( + "Fallback extraction failed", + "x" = "Error: {e$message}", + "i" = "Returning error message as answer" + )) + paste0("[Fallback extraction failed: ", e$message, "]") + } + ) + }, + + #' Format inputs for prompt display + format_inputs_for_prompt = function(inputs) { + parts <- vapply( + names(inputs), + function(name) { + val <- inputs[[name]] + if (is.character(val) && length(val) == 1) { + if (nchar(val) > 500) { + val <- paste0(substr(val, 1, 500), "... [truncated]") + } + paste0(name, ": ", val) + } else { + paste0(name, ": ", deparse(val, width.cutoff = 500)[1]) + } + }, + character(1) + ) + paste(parts, collapse = "\n") + }, + + #' Build output matching signature + build_output = function(answer) { + output_type <- self$signature@output_type + + if (methods::.hasSlot(output_type, "properties")) { + props <- output_type@properties + if (length(props) == 1) { + return(setNames(list(answer), names(props)[1])) + } + } + + list(answer = answer) + } + ) +) diff --git a/R/module.R b/R/module.R index 1c4862a..6744756 100644 --- a/R/module.R +++ b/R/module.R @@ -15,6 +15,7 @@ #' - `"multichain"`: MultiChainComparison module for ensemble reasoning #' - `"program_of_thought"`: Code execution module (requires runner) #' - `"codeact"`: Hybrid agent with tools + code execution (requires runner) +#' - `"rlm"`: Recursive Language Model for REPL-based context exploration (requires runner) #' @param tools Optional list of ellmer ToolDef objects for react modules. #' If provided with `type = "predict"`, automatically upgrades to react. 
#' @param max_iterations Maximum ReAct iterations (default: 10, only for react) @@ -124,7 +125,8 @@ module <- function( "chain_of_thought", "multichain", "program_of_thought", - "codeact" + "codeact", + "rlm" ) ) @@ -195,6 +197,24 @@ module <- function( chat = chat ) }, + rlm = { + if (is.null(runner)) { + cli::cli_abort(c( + "rlm requires a runner", + "i" = "Create one with: {.code runner <- r_code_runner()}", + "i" = "Then pass it: {.code module(..., runner = runner)}" + )) + } + rlm_module( + signature = signature, + runner = runner, + max_iterations = max_iterations, + tools = tools %||% list(), + config = config, + chat = chat, + ... + ) + }, cli::cli_abort("Unknown module type: {type}") ) } diff --git a/R/rlm-tools.R b/R/rlm-tools.R new file mode 100644 index 0000000..2b93a8f --- /dev/null +++ b/R/rlm-tools.R @@ -0,0 +1,239 @@ +#' RLM Tools - Prelude Generator +#' +#' @description +#' Generates R code that defines RLM tools in the execution environment. +#' This code is run before user-generated code via RCodeRunner. +#' +#' @details +#' The prelude defines these functions in the execution environment: +#' +#' - `SUBMIT(answer)`: Terminate and return final answer +#' - `peek(var, start, end)`: View a slice of a variable +#' - `search(var, pattern)`: Regex search in variable +#' - `rlm_query(query, context_slice)`: Request a recursive LLM call (returns marker for interception) +#' - `rlm_query_batch(queries, slices)`: Request batched LLM calls (returns marker for interception) +#' +#' The `rlm_query` and `rlm_query_batch` functions return special marker objects +#' that the main RLM process intercepts and handles. The actual LLM calls happen +#' in the parent R process, not in the sandboxed code execution environment. +#' +#' @keywords internal +#' @name rlm-tools +NULL + + +#' Create RLM Prelude Code +#' +#' @description +#' Generates R code that defines RLM tools in the execution environment. 
+#' +#' @param max_llm_calls Maximum allowed recursive LLM calls +#' @param has_sub_lm Logical indicating if recursive queries are enabled +#' @param custom_tools Named list of user-defined R functions +#' +#' @return Character string of R code defining RLM tools +#' +#' @keywords internal +#' @noRd +create_rlm_prelude <- function( + max_llm_calls = 50L, + has_sub_lm = FALSE, + custom_tools = list() +) { + # Base tools (always available) + base_prelude <- ' +# ============================================ +# RLM Tools - Injected by dsprrr +# ============================================ + +# SUBMIT: Terminate and return final answer +# When code calls SUBMIT(answer), it returns a special object +# that the main process detects to stop iteration +SUBMIT <- function(answer) { + result <- answer + class(result) <- c("rlm_final", class(result)) + attr(result, "rlm_final") <- TRUE + result +} + +# peek: View a slice of a character variable +# Useful for exploring large text contexts +peek <- function(var, start = 1L, end = 1000L) { + if (!is.character(var)) { + var <- as.character(var) + } + + if (length(var) > 1) { + # For character vectors, show elements in range + n <- length(var) + start <- max(1L, as.integer(start)) + end <- min(n, as.integer(end)) + return(var[start:end]) + } + + # For single strings, show character range + total_chars <- nchar(var) + start <- max(1L, as.integer(start)) + end <- min(total_chars, as.integer(end)) + + if (start > total_chars) { + return("") + } + + substr(var, start, end) +} + +# search: Regex search in a variable +# Returns all matches as a character vector +search <- function(var, pattern, ignore_case = FALSE) { + if (!is.character(var)) { + var <- as.character(var) + } + + # Collapse to single string if vector + if (length(var) > 1) { + var <- paste(var, collapse = "\\n") + } + + # Find all matches + matches <- regmatches( + var, + gregexpr(pattern, var, ignore.case = ignore_case, perl = TRUE) + ) + + unlist(matches) +} +' + + # 
Recursive query tools (only if sub_lm is available) + if (has_sub_lm) { + recursive_prelude <- sprintf( + ' +# rlm_query: Recursive LLM query +# Returns a request marker - main process will intercept and handle +rlm_query <- function(query, context_slice = NULL) { + # Note: This function returns a marker that the main process intercepts + + # The actual LLM call happens in the parent R process + structure( + list(query = query, context = context_slice, batch = FALSE), + class = "rlm_query_request" + ) +} + +# rlm_query_batch: Batched recursive queries +# Returns a request marker for batch processing +rlm_query_batch <- function(queries, slices = NULL) { + if (!is.character(queries)) { + stop("queries must be a character vector") + } + + if (!is.null(slices) && length(slices) != length(queries)) { + stop("slices must have same length as queries") + } + + structure( + list(queries = queries, slices = slices, batch = TRUE), + class = "rlm_query_request" + ) +} + +# Note: Maximum LLM calls allowed: %d +# Exceeding this limit will result in an error +', + max_llm_calls + ) + } else { + recursive_prelude <- ' +# rlm_query: Disabled (no sub_lm provided) +rlm_query <- function(query, context_slice = NULL) { + stop("Recursive LLM queries are disabled. Provide sub_lm to enable.") +} + +rlm_query_batch <- function(queries, slices = NULL) { + stop("Recursive LLM queries are disabled. Provide sub_lm to enable.") +} +' + } + + # Custom tools + custom_prelude <- "" + if (length(custom_tools) > 0) { + # Serialize each custom function + tool_defs <- vapply( + names(custom_tools), + function(name) { + fn <- custom_tools[[name]] + if (!is.function(fn)) { + # This should not happen if rlm_module() validation is working + # but provide a clear error in the prelude as defense in depth + return(sprintf( + "%s <- function(...) 
stop('Tool %s is not a function')\n", + name, + name + )) + } + + # Deparse the function and assign it + fn_body <- paste(deparse(fn), collapse = "\n") + sprintf("%s <- %s\n", name, fn_body) + }, + character(1) + ) + + custom_prelude <- paste0( + "\n# Custom Tools\n", + paste(tool_defs, collapse = "\n") + ) + } + + # Combine all parts + paste0( + base_prelude, + recursive_prelude, + custom_prelude, + "\n# ============================================\n" + ) +} + + +#' Check if a value is an RLM final answer +#' +#' @param x Value to check +#' @return Logical indicating if x is an rlm_final value +#' +#' @keywords internal +#' @noRd +is_rlm_final <- function(x) { + inherits(x, "rlm_final") || isTRUE(attr(x, "rlm_final")) +} + + +#' Check if a value is an RLM query request +#' +#' @param x Value to check +#' @return Logical indicating if x is an rlm_query_request +#' +#' @keywords internal +#' @noRd +is_rlm_query_request <- function(x) { + inherits(x, "rlm_query_request") +} + + +#' Extract value from RLM final answer +#' +#' @param x An rlm_final value +#' @return The underlying value with rlm_final class removed +#' +#' @keywords internal +#' @noRd +extract_rlm_final <- function(x) { + if (!is_rlm_final(x)) { + return(x) + } + + class(x) <- setdiff(class(x), "rlm_final") + attr(x, "rlm_final") <- NULL + x +} diff --git a/_pkgdown.yml b/_pkgdown.yml index 81b8fd9..c6d8d32 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -212,6 +212,8 @@ reference: - module-program-of-thought - code_act - module-codeact + - rlm_module + - module-rlm - title: Chat-Centric API desc: ellmer-style pipe-friendly functions diff --git a/man/module-rlm.Rd b/man/module-rlm.Rd new file mode 100644 index 0000000..ddaf0c9 --- /dev/null +++ b/man/module-rlm.Rd @@ -0,0 +1,62 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/module-rlm.R +\name{module-rlm} +\alias{module-rlm} +\title{Recursive Language Model (RLM) Module} +\description{ +A module that transforms 
context from "input" to "environment", enabling LLMs +to programmatically explore large contexts through a REPL interface rather +than embedding them in prompts. +} +\details{ +Instead of \code{llm(prompt, context=huge_document)}, RLM stores context as R +variables that the LLM can peek, slice, search, and recursively query. + +The execution flow is: +\enumerate{ +\item Context is made available as variables in an R execution environment +\item LLM generates R code to explore and analyze the context +\item Code is executed in an isolated subprocess via RCodeRunner +\item Results are fed back to the LLM for the next iteration +\item Process continues until SUBMIT() is called or max_iterations reached +\item If max_iterations reached without SUBMIT(), fallback extraction is used +} + +Available REPL tools: +\itemize{ +\item \code{SUBMIT(answer)}: Terminate and return final answer +\item \code{peek(var, start, end)}: View a slice of a variable (default: first 1000 chars) +\item \code{search(var, pattern)}: Regex search in variable +\item \code{rlm_query(query, context_slice)}: Recursive LLM call (requires sub_lm) +\item \code{rlm_query_batch(queries, slices)}: Batched recursive calls +} + +Security: Code execution requires explicit opt-in via a runner parameter. +The runner provides subprocess isolation but is NOT a security sandbox. +For untrusted inputs, use OS-level sandboxing (containers, AppArmor). 
+} +\examples{ +\dontrun{ +# Create a runner (required for code execution) +runner <- r_code_runner(timeout = 30) + +# Create an RLM module for exploring large documents +rlm <- rlm_module( + signature = "document, question -> answer", + runner = runner +) + +# Use it for context exploration +long_doc <- paste(readLines("large_file.txt"), collapse = "\n") +result <- run(rlm, document = long_doc, question = "What are the main themes?", .llm = llm) + +# Enable recursive LLM calls for complex reasoning +rlm_recursive <- rlm_module( + signature = "document -> summary", + runner = runner, + sub_lm = ellmer::chat_openai(model = "gpt-4o-mini"), + max_llm_calls = 10 +) +} + +} diff --git a/man/module.Rd b/man/module.Rd index 441bb5d..216c365 100644 --- a/man/module.Rd +++ b/man/module.Rd @@ -32,6 +32,7 @@ module( \item \code{"multichain"}: MultiChainComparison module for ensemble reasoning \item \code{"program_of_thought"}: Code execution module (requires runner) \item \code{"codeact"}: Hybrid agent with tools + code execution (requires runner) +\item \code{"rlm"}: Recursive Language Model for REPL-based context exploration (requires runner) }} \item{tools}{Optional list of ellmer ToolDef objects for react modules. diff --git a/man/rlm-tools.Rd b/man/rlm-tools.Rd new file mode 100644 index 0000000..16bfce8 --- /dev/null +++ b/man/rlm-tools.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rlm-tools.R +\name{rlm-tools} +\alias{rlm-tools} +\title{RLM Tools - Prelude Generator} +\description{ +Generates R code that defines RLM tools in the execution environment. +This code is run before user-generated code via RCodeRunner. 
+} +\details{ +The prelude defines these functions in the execution environment: +\itemize{ +\item \code{SUBMIT(answer)}: Terminate and return final answer +\item \code{peek(var, start, end)}: View a slice of a variable +\item \code{search(var, pattern)}: Regex search in variable +\item \code{rlm_query(query, context_slice)}: Request a recursive LLM call (returns marker for interception) +\item \code{rlm_query_batch(queries, slices)}: Request batched LLM calls (returns marker for interception) +} + +The \code{rlm_query} and \code{rlm_query_batch} functions return special marker objects +that the main RLM process intercepts and handles. The actual LLM calls happen +in the parent R process, not in the sandboxed code execution environment. +} +\keyword{internal} diff --git a/man/rlm_module.Rd b/man/rlm_module.Rd new file mode 100644 index 0000000..9b7eff1 --- /dev/null +++ b/man/rlm_module.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/module-rlm.R +\name{rlm_module} +\alias{rlm_module} +\title{Create a Recursive Language Model (RLM) Module} +\usage{ +rlm_module( + signature, + runner, + max_iterations = 20L, + max_llm_calls = 50L, + max_output_chars = 100000L, + sub_lm = NULL, + verbose = FALSE, + tools = list(), + ... +) +} +\arguments{ +\item{signature}{A Signature object or string notation defining inputs/outputs} + +\item{runner}{An RCodeRunner object for code execution. Required.} + +\item{max_iterations}{Maximum REPL iterations before fallback (default 20)} + +\item{max_llm_calls}{Maximum recursive LLM calls allowed (default 50)} + +\item{max_output_chars}{Maximum characters per execution output (default 100000)} + +\item{sub_lm}{Optional ellmer Chat for recursive queries. NULL = disabled.} + +\item{verbose}{Logical. Print execution progress (default FALSE)} + +\item{tools}{Named list of user-defined R functions to inject into REPL. +Each tool becomes available as a function in the code execution environment. 
+Non-function values in the list will cause an error.} + +\item{...}{Additional arguments passed to the module} +} +\value{ +An RLMModule object +} +\description{ +Factory function to create an RLMModule that enables LLMs to programmatically +explore large contexts through a REPL interface. +} +\examples{ +\dontrun{ +runner <- r_code_runner(timeout = 30) +rlm <- rlm_module("question -> answer", runner = runner) +result <- run(rlm, question = "What is the 10th Fibonacci number?", .llm = llm) +} +} diff --git a/tests/testthat/test-module-rlm.R b/tests/testthat/test-module-rlm.R new file mode 100644 index 0000000..b498a20 --- /dev/null +++ b/tests/testthat/test-module-rlm.R @@ -0,0 +1,935 @@ +# Tests for RLMModule (Recursive Language Model) + +# Helper: Create a mock LLM for RLM testing +create_mock_rlm_llm <- function(code_responses = list()) { + call_count <- 0 + mock <- list( + clone = function() { + create_mock_rlm_llm(code_responses) + }, + chat_structured = function(prompt, type, ...) { + call_count <<- call_count + 1 + if (call_count <= length(code_responses)) { + code_responses[[call_count]] + } else { + # Default: simple SUBMIT + list( + reasoning = "Returning default answer", + code = "SUBMIT('default answer')" + ) + } + }, + chat = function(prompt, ...) 
{ + "fallback answer" + } + ) + mock +} + +# ============================================================================ +# Factory Function Tests +# ============================================================================ + +test_that("rlm_module requires runner", { + expect_error( + rlm_module("question -> answer"), + "RLM requires an explicit runner" + ) +}) + +test_that("rlm_module validates runner type", { + expect_error( + rlm_module("question -> answer", runner = "not a runner"), + "runner must be an RCodeRunner" + ) +}) + +test_that("rlm_module creates RLMModule", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module("question -> answer", runner = runner) + + expect_s3_class(rlm, "RLMModule") + expect_s3_class(rlm, "Module") +}) + +test_that("rlm_module accepts string signature", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module("document, question -> answer", runner = runner) + + expect_equal(length(rlm$signature@inputs), 2) +}) + +test_that("rlm_module accepts Signature object", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + sig <- signature("question -> answer") + rlm <- rlm_module(sig, runner = runner) + + expect_s3_class(rlm, "RLMModule") +}) + +test_that("rlm_module respects max_iterations", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module( + "question -> answer", + runner = runner, + max_iterations = 10 + ) + + expect_equal(rlm$max_iterations, 10L) +}) + +test_that("rlm_module respects max_llm_calls", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module( + "question -> answer", + runner = runner, + max_llm_calls = 25 + ) + + expect_equal(rlm$max_llm_calls, 25L) +}) + +test_that("rlm_module validates tools parameter", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + + # Must be a list + expect_error( + 
rlm_module("question -> answer", runner = runner, tools = "not a list"), + "tools must be a named list" + ) + + # If not empty, must be named + expect_error( + rlm_module( + "question -> answer", + runner = runner, + tools = list(function() {}) + ), + "tools must be a named list" + ) +}) + +test_that("rlm_module validates all tools are functions", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + + # Non-function in tools list + expect_error( + rlm_module( + "question -> answer", + runner = runner, + tools = list(good = function() {}, bad = "not a function") + ), + "All tools must be functions" + ) + + # Multiple non-functions + expect_error( + rlm_module( + "question -> answer", + runner = runner, + tools = list(a = 1, b = "string", c = function() {}) + ), + "All tools must be functions" + ) +}) + +test_that("rlm_module validates max_iterations bounds", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + + expect_error( + rlm_module("question -> answer", runner = runner, max_iterations = 0), + "max_iterations must be at least 1" + ) + + expect_error( + rlm_module("question -> answer", runner = runner, max_iterations = -5), + "max_iterations must be at least 1" + ) +}) + +test_that("rlm_module validates max_llm_calls bounds", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + + expect_error( + rlm_module("question -> answer", runner = runner, max_llm_calls = -1), + "max_llm_calls must be non-negative" + ) + + # Zero is allowed (disables recursive calls) + expect_no_error( + rlm_module("question -> answer", runner = runner, max_llm_calls = 0) + ) +}) + +# ============================================================================ +# Module Structure Tests +# ============================================================================ + +test_that("RLMModule has correct fields", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module("question -> 
answer", runner = runner) + + expect_true(!is.null(rlm$runner)) + expect_true(!is.null(rlm$max_iterations)) + expect_true(!is.null(rlm$max_llm_calls)) + expect_true(!is.null(rlm$max_output_chars)) + expect_true(!is.null(rlm$signature)) + expect_false(rlm$verbose) + expect_null(rlm$sub_lm) +}) + +test_that("RLMModule get_repl_history returns list", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module("question -> answer", runner = runner) + + history <- rlm$get_repl_history() + + expect_type(history, "list") + expect_length(history, 0) # No executions yet +}) + +test_that("RLMModule reset_copy creates fresh module", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module( + "question -> answer", + runner = runner, + max_iterations = 15 + ) + + copy <- rlm$reset_copy() + + expect_s3_class(copy, "RLMModule") + expect_equal(copy$max_iterations, 15L) + expect_length(copy$state$repl_history, 0) +}) + +test_that("RLMModule print works", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- rlm_module("question -> answer", runner = runner) + + expect_invisible(print(rlm)) +}) + +# ============================================================================ +# SUBMIT Termination Tests +# ============================================================================ + +test_that("RLMModule terminates on SUBMIT call", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + # Mock LLM that immediately calls SUBMIT + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Computing the answer", + code = "SUBMIT(42)" + ) + )) + + result <- rlm$forward( + list(question = "What is 6 * 7?"), + .llm = mock_llm + ) + + expect_s3_class(result, "tbl_df") + expect_true("output" %in% names(result)) + expect_true("metadata" %in% names(result)) + 
expect_equal(result$metadata[[1]]$iterations, 1) +}) + +test_that("RLMModule SUBMIT returns correct value", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Simple string answer", + code = "SUBMIT('the answer is 42')" + ) + )) + + result <- rlm$forward( + list(question = "What is the meaning of life?"), + .llm = mock_llm + ) + + # The output should contain the submitted answer + expect_equal(result$output[[1]]$answer, "the answer is 42") +}) + +test_that("RLMModule SUBMIT works with complex values", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Returning a list", + code = "SUBMIT(list(value = 42, unit = 'answer'))" + ) + )) + + result <- rlm$forward( + list(question = "Give me structured data"), + .llm = mock_llm + ) + + expect_type(result$output[[1]]$answer, "list") + expect_equal(result$output[[1]]$answer$value, 42) +}) + +# ============================================================================ +# REPL Tools Tests +# ============================================================================ + +test_that("RLM peek function works in REPL", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "document -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Peek at document", + code = "first_100 <- peek(.context$document, 1, 100)\nSUBMIT(first_100)" + ) + )) + + long_doc <- paste(rep("Hello world. 
", 100), collapse = "") + + result <- rlm$forward( + list(document = long_doc), + .llm = mock_llm + ) + + # Should get first 100 characters + expect_equal(nchar(result$output[[1]]$answer), 100) + expect_true(startsWith(result$output[[1]]$answer, "Hello world")) +}) + +test_that("RLM search function works in REPL", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "document -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Search for emails", + code = "matches <- search(.context$document, '[a-z]+@[a-z]+\\\\.[a-z]+')\nSUBMIT(paste(matches, collapse=', '))" + ) + )) + + doc <- "Contact us at test@example.com or info@example.org for more info." + + result <- rlm$forward( + list(document = doc), + .llm = mock_llm + ) + + expect_true(grepl("test@example.com", result$output[[1]]$answer)) + expect_true(grepl("info@example.org", result$output[[1]]$answer)) +}) + +# ============================================================================ +# Multiple Iterations Tests +# ============================================================================ + +test_that("RLMModule supports multiple iterations", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + # First iteration explores, second submits + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "First, let me compute something", + code = "x <- 10 + 32" + ), + list( + reasoning = "Now submit the answer", + code = "SUBMIT(42)" + ) + )) + + result <- rlm$forward( + list(question = "What is 10 + 32?"), + .llm = mock_llm + ) + + expect_equal(result$metadata[[1]]$iterations, 2) +}) + +test_that("RLMModule respects max_iterations", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner, + max_iterations = 2 + ) + + # LLM never calls SUBMIT + 
mock_llm <- create_mock_rlm_llm(list(
+    list(reasoning = "Step 1", code = "x <- 1"),
+    list(reasoning = "Step 2", code = "y <- 2"),
+    list(reasoning = "Step 3", code = "z <- 3") # Won't be reached
+  ))
+
+  result <- rlm$forward(
+    list(question = "test"),
+    .llm = mock_llm
+  )
+
+  # Should use fallback after max_iterations
+  # Iteration count is capped at max_iterations (2), so the third
+  # scripted response is never consumed
+  expect_equal(result$metadata[[1]]$iterations, 2)
+  expect_equal(result$metadata[[1]]$max_iterations, 2)
+})
+
+test_that("RLMModule uses fallback when no SUBMIT", {
+  skip_if_not_installed("callr")
+
+  runner <- r_code_runner(timeout = 10)
+  rlm <- rlm_module(
+    "question -> answer",
+    runner = runner,
+    max_iterations = 1
+  )
+
+  # LLM never calls SUBMIT
+  mock_llm <- create_mock_rlm_llm(list(
+    list(reasoning = "Just computing", code = "1 + 1")
+  ))
+
+  result <- rlm$forward(
+    list(question = "test"),
+    .llm = mock_llm
+  )
+
+  # Should have used fallback extraction
+  # "fallback answer" is what the mock helper's chat() method returns,
+  # so matching it proves the unstructured chat() fallback path ran
+  expect_equal(result$output[[1]]$answer, "fallback answer")
+})
+
+# ============================================================================
+# Error Handling Tests
+# ============================================================================
+
+test_that("RLMModule handles code execution errors", {
+  skip_if_not_installed("callr")
+
+  runner <- r_code_runner(timeout = 10)
+  rlm <- rlm_module(
+    "question -> answer",
+    runner = runner,
+    max_iterations = 3
+  )
+
+  # First iteration fails, second succeeds
+  mock_llm <- create_mock_rlm_llm(list(
+    list(reasoning = "Try something", code = "stop('intentional error')"),
+    list(reasoning = "Fixed it", code = "SUBMIT('success')")
+  ))
+
+  result <- rlm$forward(
+    list(question = "test"),
+    .llm = mock_llm
+  )
+
+  # The failing first iteration still counts toward the iteration total,
+  # and the loop recovers rather than aborting on the execution error
+  expect_equal(result$metadata[[1]]$iterations, 2)
+  expect_equal(result$output[[1]]$answer, "success")
+})
+
+# ============================================================================
+# Custom Tools Tests
+# ============================================================================
+
+test_that("RLMModule supports 
custom tools", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + + # Define a custom tool + custom_tools <- list( + double_it = function(x) x * 2 + ) + + rlm <- rlm_module( + "question -> answer", + runner = runner, + tools = custom_tools + ) + + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Using custom tool", + code = "result <- double_it(21)\nSUBMIT(result)" + ) + )) + + result <- rlm$forward( + list(question = "What is double 21?"), + .llm = mock_llm + ) + + expect_equal(result$output[[1]]$answer, 42) +}) + +# ============================================================================ +# REPL History Tests +# ============================================================================ + +test_that("RLMModule stores REPL history", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list(reasoning = "Computing", code = "SUBMIT('done')") + )) + + rlm$forward( + list(question = "test"), + .llm = mock_llm + ) + + history <- rlm$get_repl_history() + + expect_length(history, 1) + expect_true("inputs" %in% names(history[[1]])) + expect_true("history" %in% names(history[[1]])) + expect_true("final_answer" %in% names(history[[1]])) +}) + +test_that("RLMModule respects trace=FALSE", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list(reasoning = "Computing", code = "SUBMIT('done')") + )) + + rlm$forward(list(question = "test"), .llm = mock_llm, trace = FALSE) + + # No history should be stored when trace=FALSE + expect_length(rlm$get_repl_history(), 0) +}) + +# ============================================================================ +# Metadata Tests +# ============================================================================ + +test_that("RLMModule 
returns correct metadata", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner, + max_iterations = 20, + max_llm_calls = 50 + ) + + mock_llm <- create_mock_rlm_llm(list( + list(reasoning = "Direct answer", code = "SUBMIT('pi')") + )) + + result <- rlm$forward( + list(question = "What is pi?"), + .llm = mock_llm + ) + + metadata <- result$metadata[[1]] + + expect_equal(metadata$model, "rlm") + expect_true("iterations" %in% names(metadata)) + expect_true("max_iterations" %in% names(metadata)) + expect_true("llm_calls" %in% names(metadata)) + expect_true("max_llm_calls" %in% names(metadata)) + expect_true("duration_ms" %in% names(metadata)) + expect_true("repl_history" %in% names(metadata)) + + expect_equal(metadata$max_iterations, 20) + expect_equal(metadata$max_llm_calls, 50) +}) + +# ============================================================================ +# Integration with module() factory +# ============================================================================ + +test_that("module() factory works with type='rlm'", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 5) + rlm <- module( + signature("question -> answer"), + type = "rlm", + runner = runner + ) + + expect_s3_class(rlm, "RLMModule") +}) + +test_that("module() factory requires runner for rlm", { + expect_error( + module( + signature("question -> answer"), + type = "rlm" + ), + "rlm requires a runner" + ) +}) + +# ============================================================================ +# Context Description Tests +# ============================================================================ + +test_that("RLMModule handles various context types", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "text, numbers, df -> answer", + runner = runner + ) + + mock_llm <- create_mock_rlm_llm(list( + list( + reasoning = "Process inputs", + 
code = "SUBMIT(paste('text:', nchar(.context$text), 'numbers:', length(.context$numbers), 'df rows:', nrow(.context$df)))" + ) + )) + + result <- rlm$forward( + list( + text = "Hello world", + numbers = c(1, 2, 3, 4, 5), + df = data.frame(a = 1:3, b = 4:6) + ), + .llm = mock_llm + ) + + expect_true(grepl("text: 11", result$output[[1]]$answer)) + expect_true(grepl("numbers: 5", result$output[[1]]$answer)) + expect_true(grepl("df rows: 3", result$output[[1]]$answer)) +}) + +# ============================================================================ +# RLM Tools Unit Tests +# ============================================================================ + +test_that("create_rlm_prelude generates valid R code", { + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = FALSE, + custom_tools = list() + ) + + expect_type(prelude, "character") + # Parse should succeed + expect_no_error(parse(text = prelude)) +}) + +test_that("create_rlm_prelude includes SUBMIT function", { + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = FALSE, + custom_tools = list() + ) + + expect_true(grepl("SUBMIT <- function", prelude)) +}) + +test_that("create_rlm_prelude includes peek function", { + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = FALSE, + custom_tools = list() + ) + + expect_true(grepl("peek <- function", prelude)) +}) + +test_that("create_rlm_prelude includes search function", { + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = FALSE, + custom_tools = list() + ) + + expect_true(grepl("search <- function", prelude)) +}) + +test_that("create_rlm_prelude includes rlm_query when sub_lm enabled", { + prelude_with <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = TRUE, + custom_tools = list() + ) + + prelude_without <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = FALSE, + custom_tools = list() + ) + + # With sub_lm: should have working 
rlm_query + expect_true(grepl("rlm_query_request", prelude_with)) + + # Without sub_lm: should have disabled rlm_query + expect_true(grepl("Recursive LLM queries are disabled", prelude_without)) +}) + +test_that("create_rlm_prelude includes custom tools", { + custom_tools <- list( + my_tool = function(x) x * 2, + another_tool = function(a, b) a + b + ) + + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = FALSE, + custom_tools = custom_tools + ) + + expect_true(grepl("my_tool <-", prelude)) + expect_true(grepl("another_tool <-", prelude)) +}) + +test_that("is_rlm_final detects rlm_final class", { + # Regular value + expect_false(dsprrr:::is_rlm_final(42)) + expect_false(dsprrr:::is_rlm_final("hello")) + + # rlm_final value + final <- structure("answer", class = c("rlm_final", "character")) + expect_true(dsprrr:::is_rlm_final(final)) + + # Via attribute + final2 <- "answer" + attr(final2, "rlm_final") <- TRUE + expect_true(dsprrr:::is_rlm_final(final2)) +}) + +test_that("extract_rlm_final removes rlm_final class", { + final <- structure("answer", class = c("rlm_final", "character")) + attr(final, "rlm_final") <- TRUE + + extracted <- dsprrr:::extract_rlm_final(final) + + expect_false(dsprrr:::is_rlm_final(extracted)) + expect_equal(extracted, "answer") + expect_null(attr(extracted, "rlm_final")) +}) + +# ============================================================================ +# Error Handling and Edge Case Tests +# ============================================================================ + +test_that("RLMModule warns when max_iterations reached without SUBMIT", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner, + max_iterations = 2 + ) + + # LLM never calls SUBMIT + mock_llm <- create_mock_rlm_llm(list( + list(reasoning = "Step 1", code = "x <- 1"), + list(reasoning = "Step 2", code = "y <- 2") + )) + + expect_warning( + 
rlm$forward(list(question = "test"), .llm = mock_llm), + "reached max_iterations" + ) +}) + +test_that("RLMModule handles LLM response with missing code", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + # Mock LLM that returns invalid response (missing code) + mock_llm <- list( + clone = function() mock_llm, + chat_structured = function(prompt, type, ...) { + list(reasoning = "Thinking") + # Missing 'code' field + }, + chat = function(prompt, ...) "fallback" + ) + + expect_error( + rlm$forward(list(question = "test"), .llm = mock_llm), + "invalid" + ) +}) + +test_that("RLMModule handles LLM response with non-string code", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + rlm <- rlm_module( + "question -> answer", + runner = runner + ) + + # Mock LLM that returns code as number + mock_llm <- list( + clone = function() mock_llm, + chat_structured = function(prompt, type, ...) { + list(reasoning = "Thinking", code = 123) + }, + chat = function(prompt, ...) 
"fallback" + ) + + expect_error( + rlm$forward(list(question = "test"), .llm = mock_llm), + "invalid" + ) +}) + +# ============================================================================ +# rlm_query Batch Tests +# ============================================================================ + +test_that("is_rlm_query_request detects query requests", { + regular <- 42 + expect_false(dsprrr:::is_rlm_query_request(regular)) + + request <- structure( + list(query = "test", context = NULL, batch = FALSE), + class = "rlm_query_request" + ) + expect_true(dsprrr:::is_rlm_query_request(request)) +}) + +test_that("rlm_query_batch generates batch request marker", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + + # Execute prelude in subprocess to test rlm_query_batch + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = TRUE, + custom_tools = list() + ) + + result <- runner$execute( + paste0(prelude, "\nrlm_query_batch(c('q1', 'q2'))"), + context = list() + ) + + expect_true(result$success) + expect_s3_class(result$result, "rlm_query_request") + expect_true(result$result$batch) + expect_equal(result$result$queries, c("q1", "q2")) +}) + +test_that("rlm_query_batch validates queries is character", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = TRUE, + custom_tools = list() + ) + + result <- runner$execute( + paste0(prelude, "\nrlm_query_batch(123)"), + context = list() + ) + + expect_false(result$success) + expect_true(grepl("character vector", result$error)) +}) + +test_that("rlm_query_batch validates slices length", { + skip_if_not_installed("callr") + + runner <- r_code_runner(timeout = 10) + + prelude <- dsprrr:::create_rlm_prelude( + max_llm_calls = 50, + has_sub_lm = TRUE, + custom_tools = list() + ) + + result <- runner$execute( + paste0(prelude, "\nrlm_query_batch(c('q1', 'q2'), slices = c('s1'))"), + 
context = list() + ) + + expect_false(result$success) + expect_true(grepl("same length", result$error)) +})