diff --git a/r/sedonadb/NAMESPACE b/r/sedonadb/NAMESPACE index 11a141b96..a7f631bf0 100644 --- a/r/sedonadb/NAMESPACE +++ b/r/sedonadb/NAMESPACE @@ -45,11 +45,14 @@ export(sd_expr_factory) export(sd_expr_literal) export(sd_expr_negative) export(sd_expr_scalar_function) +export(sd_filter) export(sd_preview) export(sd_read_parquet) export(sd_register_udf) +export(sd_select) export(sd_sql) export(sd_to_view) +export(sd_transmute) export(sd_view) export(sd_write_parquet) export(sedonadb_adbc) diff --git a/r/sedonadb/R/000-wrappers.R b/r/sedonadb/R/000-wrappers.R index 6cd4654f2..40bb92637 100644 --- a/r/sedonadb/R/000-wrappers.R +++ b/r/sedonadb/R/000-wrappers.R @@ -212,6 +212,16 @@ class(`InternalContext`) <- c( } } +`InternalDataFrame_filter` <- function(self) { + function(`exprs_sexp`) { + .savvy_wrap_InternalDataFrame(.Call( + savvy_InternalDataFrame_filter__impl, + `self`, + `exprs_sexp` + )) + } +} + `InternalDataFrame_limit` <- function(self) { function(`n`) { .savvy_wrap_InternalDataFrame(.Call(savvy_InternalDataFrame_limit__impl, `self`, `n`)) @@ -224,6 +234,16 @@ class(`InternalContext`) <- c( } } +`InternalDataFrame_select` <- function(self) { + function(`exprs_sexp`) { + .savvy_wrap_InternalDataFrame(.Call( + savvy_InternalDataFrame_select__impl, + `self`, + `exprs_sexp` + )) + } +} + `InternalDataFrame_select_indices` <- function(self) { function(`names`, `indices`) { .savvy_wrap_InternalDataFrame(.Call( @@ -316,10 +336,12 @@ class(`InternalContext`) <- c( e$`collect` <- `InternalDataFrame_collect`(ptr) e$`compute` <- `InternalDataFrame_compute`(ptr) e$`count` <- `InternalDataFrame_count`(ptr) + e$`filter` <- `InternalDataFrame_filter`(ptr) e$`limit` <- `InternalDataFrame_limit`(ptr) e$`primary_geometry_column_index` <- `InternalDataFrame_primary_geometry_column_index`( ptr ) + e$`select` <- `InternalDataFrame_select`(ptr) e$`select_indices` <- `InternalDataFrame_select_indices`(ptr) e$`show` <- `InternalDataFrame_show`(ptr) e$`to_arrow_schema` <- 
`InternalDataFrame_to_arrow_schema`(ptr) diff --git a/r/sedonadb/R/dataframe.R b/r/sedonadb/R/dataframe.R index fefc3a3d2..464210fc1 100644 --- a/r/sedonadb/R/dataframe.R +++ b/r/sedonadb/R/dataframe.R @@ -80,7 +80,7 @@ as_sedonadb_dataframe.datafusion_table_provider <- function(x, ..., schema = NUL #' Count rows in a DataFrame #' -#' @param .data A sedonadb_dataframe +#' @param .data A sedonadb_dataframe or an object that can be coerced to one. #' #' @returns The number of rows after executing the query #' @export @@ -89,6 +89,7 @@ as_sedonadb_dataframe.datafusion_table_provider <- function(x, ..., schema = NUL #' sd_sql("SELECT 1 as one") |> sd_count() #' sd_count <- function(.data) { + .data <- as_sedonadb_dataframe(.data) .data$df$count() } @@ -193,6 +194,91 @@ sd_preview <- function(.data, n = NULL, ascii = NULL, width = NULL) { invisible(.data) } +#' Keep or drop columns of a SedonaDB DataFrame +#' +#' @inheritParams sd_count +#' @param ... One or more bare names. Evaluated like [dplyr::select()]. +#' +#' @returns An object of class sedonadb_dataframe +#' @export +#' +#' @examples +#' data.frame(x = 1:10, y = letters[1:10]) |> sd_select(x) +#' +sd_select <- function(.data, ...) { + .data <- as_sedonadb_dataframe(.data) + schema <- nanoarrow::infer_nanoarrow_schema(.data) + ptype <- nanoarrow::infer_nanoarrow_ptype(schema) + loc <- tidyselect::eval_select(rlang::expr(c(...)), data = ptype) + + df <- .data$df$select_indices(names(loc), loc - 1L) + new_sedonadb_dataframe(.data$ctx, df) +} + +#' Create, modify, and delete columns of a SedonaDB DataFrame +#' +#' @inheritParams sd_count +#' @param ... Named expressions for new columns to create. These are evaluated +#' in the same way as [dplyr::transmute()] except does not support extra +#' dplyr features such as `across()` or `.by`. +#' +#' @returns An object of class sedonadb_dataframe +#' @export +#' +#' @examples +#' data.frame(x = 1:10) |> +#' sd_transmute(y = x + 1L) +#' +sd_transmute <- function(.data, ...) 
{ + .data <- as_sedonadb_dataframe(.data) + expr_quos <- rlang::enquos(...) + env <- parent.frame() + + expr_ctx <- sd_expr_ctx(infer_nanoarrow_schema(.data), env) + r_exprs <- expr_quos |> rlang::quos_auto_name() |> lapply(rlang::quo_get_expr) + sd_exprs <- lapply(r_exprs, sd_eval_expr, expr_ctx = expr_ctx, env = env) + + # Ensure inputs are given aliases to account for the expected column name + exprs_names <- names(r_exprs) + for (i in seq_along(sd_exprs)) { + name <- exprs_names[i] + if (!is.na(name) && name != "") { + sd_exprs[[i]] <- sd_expr_alias(sd_exprs[[i]], name, expr_ctx$factory) + } + } + + df <- .data$df$select(sd_exprs) + new_sedonadb_dataframe(.data$ctx, df) +} + +#' Keep rows of a SedonaDB DataFrame that match a condition +#' +#' @inheritParams sd_count +#' @param ... Unnamed expressions for filter conditions. These are evaluated +#' in the same way as [dplyr::filter()] except does not support extra +#' dplyr features such as `across()` or `.by`. +#' +#' @returns An object of class sedonadb_dataframe +#' @export +#' +#' @examples +#' data.frame(x = 1:10) |> sd_filter(x > 5) +#' +sd_filter <- function(.data, ...) { + .data <- as_sedonadb_dataframe(.data) + rlang::check_dots_unnamed() + + expr_quos <- rlang::enquos(...) + env <- parent.frame() + + expr_ctx <- sd_expr_ctx(infer_nanoarrow_schema(.data), env) + r_exprs <- expr_quos |> lapply(rlang::quo_get_expr) + sd_exprs <- lapply(r_exprs, sd_eval_expr, expr_ctx = expr_ctx, env = env) + + df <- .data$df$filter(sd_exprs) + new_sedonadb_dataframe(.data$ctx, df) +} + #' Write DataFrame to (Geo)Parquet files #' #' Write this DataFrame to one or more (Geo)Parquet files. 
For input that contains @@ -246,6 +332,8 @@ sd_write_parquet <- function( geoparquet_version = "1.0", overwrite_bbox_columns = FALSE ) { + .data <- as_sedonadb_dataframe(.data) + # Determine single_file_output default based on path and partition_by if (is.null(single_file_output)) { single_file_output <- length(partition_by) == 0 && grepl("\\.parquet$", path) diff --git a/r/sedonadb/R/expression.R b/r/sedonadb/R/expression.R index cca754a2c..0d97c8989 100644 --- a/r/sedonadb/R/expression.R +++ b/r/sedonadb/R/expression.R @@ -138,6 +138,7 @@ print.SedonaDBExpr <- function(x, ...) { #' #' @param expr An R expression (e.g., the result of `quote()`). #' @param expr_ctx An `sd_expr_ctx()` +#' @param env An evaluation environment. Defaults to the calling environment. #' #' @returns A `SedonaDBExpr` #' @noRd diff --git a/r/sedonadb/man/sd_compute.Rd b/r/sedonadb/man/sd_compute.Rd index 97590fd67..ecf7de0b9 100644 --- a/r/sedonadb/man/sd_compute.Rd +++ b/r/sedonadb/man/sd_compute.Rd @@ -10,7 +10,7 @@ sd_compute(.data) sd_collect(.data, ptype = NULL) } \arguments{ -\item{.data}{A sedonadb_dataframe} +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} \item{ptype}{The target R object. 
See \link[nanoarrow:convert_array_stream]{nanoarrow::convert_array_stream}.} } diff --git a/r/sedonadb/man/sd_count.Rd b/r/sedonadb/man/sd_count.Rd index c93b9d53d..fb48dd285 100644 --- a/r/sedonadb/man/sd_count.Rd +++ b/r/sedonadb/man/sd_count.Rd @@ -7,7 +7,7 @@ sd_count(.data) } \arguments{ -\item{.data}{A sedonadb_dataframe} +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} } \value{ The number of rows after executing the query diff --git a/r/sedonadb/man/sd_filter.Rd b/r/sedonadb/man/sd_filter.Rd new file mode 100644 index 000000000..f5e642347 --- /dev/null +++ b/r/sedonadb/man/sd_filter.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dataframe.R +\name{sd_filter} +\alias{sd_filter} +\title{Keep rows of a SedonaDB DataFrame that match a condition} +\usage{ +sd_filter(.data, ...) +} +\arguments{ +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} + +\item{...}{Unnamed expressions for filter conditions. These are evaluated +in the same way as \code{\link[dplyr:filter]{dplyr::filter()}} except does not support extra +dplyr features such as \code{across()} or \code{.by}.} +} +\value{ +An object of class sedonadb_dataframe +} +\description{ +Keep rows of a SedonaDB DataFrame that match a condition +} +\examples{ +data.frame(x = 1:10) |> sd_filter(x > 5) + +} diff --git a/r/sedonadb/man/sd_preview.Rd b/r/sedonadb/man/sd_preview.Rd index 351dd5a76..c9e09f0af 100644 --- a/r/sedonadb/man/sd_preview.Rd +++ b/r/sedonadb/man/sd_preview.Rd @@ -7,7 +7,7 @@ sd_preview(.data, n = NULL, ascii = NULL, width = NULL) } \arguments{ -\item{.data}{A sedonadb_dataframe} +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} \item{n}{The number of rows to preview. Use \code{Inf} to preview all rows. 
Defaults to \code{getOption("pillar.print_max")}.} diff --git a/r/sedonadb/man/sd_select.Rd b/r/sedonadb/man/sd_select.Rd new file mode 100644 index 000000000..9ef542861 --- /dev/null +++ b/r/sedonadb/man/sd_select.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dataframe.R +\name{sd_select} +\alias{sd_select} +\title{Keep or drop columns of a SedonaDB DataFrame} +\usage{ +sd_select(.data, ...) +} +\arguments{ +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} + +\item{...}{One or more bare names. Evaluated like \code{\link[dplyr:select]{dplyr::select()}}.} +} +\value{ +An object of class sedonadb_dataframe +} +\description{ +Keep or drop columns of a SedonaDB DataFrame +} +\examples{ +data.frame(x = 1:10, y = letters[1:10]) |> sd_select(x) + +} diff --git a/r/sedonadb/man/sd_to_view.Rd b/r/sedonadb/man/sd_to_view.Rd index 5c3ab020c..dce288498 100644 --- a/r/sedonadb/man/sd_to_view.Rd +++ b/r/sedonadb/man/sd_to_view.Rd @@ -7,7 +7,7 @@ sd_to_view(.data, table_ref, overwrite = FALSE) } \arguments{ -\item{.data}{A sedonadb_dataframe} +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} \item{table_ref}{The name of the view reference} diff --git a/r/sedonadb/man/sd_transmute.Rd b/r/sedonadb/man/sd_transmute.Rd new file mode 100644 index 000000000..750e3fe81 --- /dev/null +++ b/r/sedonadb/man/sd_transmute.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dataframe.R +\name{sd_transmute} +\alias{sd_transmute} +\title{Create, modify, and delete columns of a SedonaDB DataFrame} +\usage{ +sd_transmute(.data, ...) +} +\arguments{ +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} + +\item{...}{Named expressions for new columns to create. 
These are evaluated +in the same way as \code{\link[dplyr:transmute]{dplyr::transmute()}} except does not support extra +dplyr features such as \code{across()} or \code{.by}.} +} +\value{ +An object of class sedonadb_dataframe +} +\description{ +Create, modify, and delete columns of a SedonaDB DataFrame +} +\examples{ +data.frame(x = 1:10) |> + sd_transmute(y = x + 1L) + +} diff --git a/r/sedonadb/man/sd_write_parquet.Rd b/r/sedonadb/man/sd_write_parquet.Rd index afd483849..17aa01047 100644 --- a/r/sedonadb/man/sd_write_parquet.Rd +++ b/r/sedonadb/man/sd_write_parquet.Rd @@ -15,7 +15,7 @@ sd_write_parquet( ) } \arguments{ -\item{.data}{A sedonadb_dataframe} +\item{.data}{A sedonadb_dataframe or an object that can be coerced to one.} \item{path}{A filename or directory to which parquet file(s) should be written} diff --git a/r/sedonadb/src/init.c b/r/sedonadb/src/init.c index 0e9efae4b..48cc32900 100644 --- a/r/sedonadb/src/init.c +++ b/r/sedonadb/src/init.c @@ -149,6 +149,11 @@ SEXP savvy_InternalDataFrame_count__impl(SEXP self__) { return handle_result(res); } +SEXP savvy_InternalDataFrame_filter__impl(SEXP self__, SEXP c_arg__exprs_sexp) { + SEXP res = savvy_InternalDataFrame_filter__ffi(self__, c_arg__exprs_sexp); + return handle_result(res); +} + SEXP savvy_InternalDataFrame_limit__impl(SEXP self__, SEXP c_arg__n) { SEXP res = savvy_InternalDataFrame_limit__ffi(self__, c_arg__n); return handle_result(res); @@ -159,6 +164,11 @@ SEXP savvy_InternalDataFrame_primary_geometry_column_index__impl(SEXP self__) { return handle_result(res); } +SEXP savvy_InternalDataFrame_select__impl(SEXP self__, SEXP c_arg__exprs_sexp) { + SEXP res = savvy_InternalDataFrame_select__ffi(self__, c_arg__exprs_sexp); + return handle_result(res); +} + SEXP savvy_InternalDataFrame_select_indices__impl(SEXP self__, SEXP c_arg__names, SEXP c_arg__indices) { @@ -312,10 +322,14 @@ static const R_CallMethodDef CallEntries[] = { (DL_FUNC)&savvy_InternalDataFrame_compute__impl, 2}, 
{"savvy_InternalDataFrame_count__impl", (DL_FUNC)&savvy_InternalDataFrame_count__impl, 1}, + {"savvy_InternalDataFrame_filter__impl", + (DL_FUNC)&savvy_InternalDataFrame_filter__impl, 2}, {"savvy_InternalDataFrame_limit__impl", (DL_FUNC)&savvy_InternalDataFrame_limit__impl, 2}, {"savvy_InternalDataFrame_primary_geometry_column_index__impl", (DL_FUNC)&savvy_InternalDataFrame_primary_geometry_column_index__impl, 1}, + {"savvy_InternalDataFrame_select__impl", + (DL_FUNC)&savvy_InternalDataFrame_select__impl, 2}, {"savvy_InternalDataFrame_select_indices__impl", (DL_FUNC)&savvy_InternalDataFrame_select_indices__impl, 3}, {"savvy_InternalDataFrame_show__impl", diff --git a/r/sedonadb/src/rust/api.h b/r/sedonadb/src/rust/api.h index fac6258bd..b43df3f52 100644 --- a/r/sedonadb/src/rust/api.h +++ b/r/sedonadb/src/rust/api.h @@ -42,8 +42,10 @@ SEXP savvy_InternalContext_view__ffi(SEXP self__, SEXP c_arg__table_ref); SEXP savvy_InternalDataFrame_collect__ffi(SEXP self__, SEXP c_arg__out); SEXP savvy_InternalDataFrame_compute__ffi(SEXP self__, SEXP c_arg__ctx); SEXP savvy_InternalDataFrame_count__ffi(SEXP self__); +SEXP savvy_InternalDataFrame_filter__ffi(SEXP self__, SEXP c_arg__exprs_sexp); SEXP savvy_InternalDataFrame_limit__ffi(SEXP self__, SEXP c_arg__n); SEXP savvy_InternalDataFrame_primary_geometry_column_index__ffi(SEXP self__); +SEXP savvy_InternalDataFrame_select__ffi(SEXP self__, SEXP c_arg__exprs_sexp); SEXP savvy_InternalDataFrame_select_indices__ffi(SEXP self__, SEXP c_arg__names, SEXP c_arg__indices); SEXP savvy_InternalDataFrame_show__ffi(SEXP self__, SEXP c_arg__ctx, diff --git a/r/sedonadb/src/rust/src/dataframe.rs b/r/sedonadb/src/rust/src/dataframe.rs index e34cee82d..275b1f4f5 100644 --- a/r/sedonadb/src/rust/src/dataframe.rs +++ b/r/sedonadb/src/rust/src/dataframe.rs @@ -21,6 +21,7 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader}; use datafusion::catalog::MemTable; use datafusion::prelude::DataFrame; use datafusion_common::Column; +use 
datafusion_expr::utils::conjunction; use datafusion_expr::{select_expr::SelectExpr, Expr, SortExpr}; use datafusion_ffi::table_provider::FFI_TableProvider; use savvy::{savvy, savvy_err, sexp, IntoExtPtrSexp, Result}; @@ -33,6 +34,7 @@ use std::{iter::zip, ptr::swap_nonoverlapping, sync::Arc}; use tokio::runtime::Runtime; use crate::context::InternalContext; +use crate::expression::SedonaDBExprFactory; use crate::ffi::{import_schema, FFITableProviderR}; use crate::runtime::wait_for_future_captured_r; @@ -311,4 +313,21 @@ impl InternalDataFrame { let inner = self.inner.clone().select(exprs)?; Ok(new_data_frame(inner, self.runtime.clone())) } + + fn select(&self, exprs_sexp: savvy::Sexp) -> savvy::Result<InternalDataFrame> { + let exprs = SedonaDBExprFactory::exprs(exprs_sexp)?; + let inner = self.inner.clone().select(exprs)?; + Ok(new_data_frame(inner, self.runtime.clone())) + } + + fn filter(&self, exprs_sexp: savvy::Sexp) -> savvy::Result<InternalDataFrame> { + let exprs = SedonaDBExprFactory::exprs(exprs_sexp)?; + let inner = if let Some(single_filter) = conjunction(exprs) { + self.inner.clone().filter(single_filter)? + } else { + self.inner.clone() + }; + + Ok(new_data_frame(inner, self.runtime.clone())) + } } diff --git a/r/sedonadb/src/rust/src/expression.rs b/r/sedonadb/src/rust/src/expression.rs index 0add4b535..e0753fd28 100644 --- a/r/sedonadb/src/rust/src/expression.rs +++ b/r/sedonadb/src/rust/src/expression.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use datafusion_common::{Column, ScalarValue}; +use datafusion_common::{Column, Result, ScalarValue}; use datafusion_expr::{ expr::{AggregateFunction, FieldMetadata, NullTreatment, ScalarFunction}, BinaryExpr, Cast, Expr, Operator, @@ -175,7 +175,7 @@ impl SedonaDBExprFactory { } impl SedonaDBExprFactory { - fn exprs(exprs_sexp: savvy::Sexp) -> savvy::Result<Vec<Expr>> { + pub fn exprs(exprs_sexp: savvy::Sexp) -> savvy::Result<Vec<Expr>> { savvy::ListSexp::try_from(exprs_sexp)? 
.iter() .map(|(_, item)| -> savvy::Result<Expr> { diff --git a/r/sedonadb/tests/testthat/test-dataframe.R b/r/sedonadb/tests/testthat/test-dataframe.R index 0bc3c4c1c..3d96c2398 100644 --- a/r/sedonadb/tests/testthat/test-dataframe.R +++ b/r/sedonadb/tests/testthat/test-dataframe.R @@ -286,3 +286,71 @@ test_that("sd_write_parquet validates geoparquet_version parameter", { "geoparquet_version must be" ) }) + +test_that("sd_select() works with dplyr-like select syntax", { + skip_if_not_installed("tidyselect") + + df_in <- data.frame(one = 1, two = "two", THREE = 3.0) + + expect_identical( + df_in |> sd_select(2:3) |> sd_collect(), + data.frame(two = "two", THREE = 3.0) + ) + + expect_identical( + df_in |> sd_select(three_renamed = THREE, one) |> sd_collect(), + data.frame(three_renamed = 3.0, one = 1) + ) + + expect_identical( + df_in |> sd_select(TWO = two) |> sd_collect(), + data.frame(TWO = "two") + ) +}) + +test_that("sd_transmute() works with dplyr-like transmute syntax", { + df_in <- data.frame(x = 1:10) + + # checks that (1) unnamed inputs like `x` are named `x` in the output + # and (2) named inputs are given an alias and (3) expressions are + # translated. 
+ expect_identical( + df_in |> sd_transmute(x, y = x + 1L) |> sd_collect(), + data.frame(x = 1:10, y = 2:11) + ) + + # Check that the calling environment is handled + integer_one <- 1L + expect_identical( + df_in |> sd_transmute(x, y = x + integer_one) |> sd_collect(), + data.frame(x = 1:10, y = 2:11) + ) +}) + +test_that("sd_filter() works with dplyr-like filter syntax", { + df_in <- data.frame(x = 1:10) + + # Zero conditions + expect_identical( + df_in |> sd_filter() |> sd_collect(), + df_in + ) + + # One condition + expect_identical( + df_in |> sd_filter(x >= 5) |> sd_collect(), + data.frame(x = 5:10) + ) + + # Multiple conditions + expect_identical( + df_in |> sd_filter(x >= 5, x >= 6) |> sd_collect(), + data.frame(x = 6:10) + ) + + # Ensure null handling of conditions is dplyr-like (drops nulls) + expect_identical( + df_in |> sd_filter(x >= NA_integer_) |> sd_collect(), + data.frame(x = integer()) + ) +})