Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions src/bin/datu/commands/head.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
use anyhow::Result;
use anyhow::bail;
use datu::FileType;
use datu::cli::HeadsOrTails;
use datu::pipeline::RecordBatchReaderSource;
use datu::pipeline::Step;
use datu::pipeline::VecRecordBatchReaderSource;
use datu::pipeline::display::DisplayWriterStep;
use datu::pipeline::read_to_batches;
use datu::pipeline::build_reader;
use datu::pipeline::display::apply_select_and_display;
use datu::pipeline::record_batch_filter::parse_select_step;
use datu::resolve_input_file_type;

/// head command implementation: print the first N lines of an Avro, CSV, Parquet, or ORC file.
pub async fn head(args: HeadsOrTails) -> Result<()> {
let input_file_type = resolve_input_file_type(args.input, &args.input_path)?;
let batches = read_to_batches(
match input_file_type {
FileType::Parquet | FileType::Avro | FileType::Csv | FileType::Orc => {}
_ => bail!("Only Parquet, Avro, CSV, and ORC are supported for head"),
}
// Pass offset=0 when limiting so ORC row selection applies (it requires both offset and limit).
let reader_step = build_reader(
&args.input_path,
input_file_type,
&args.select,
Some(args.number),
Some(0),
args.input_headers,
)?;
apply_select_and_display(
reader_step,
parse_select_step(&args.select),
args.output,
args.sparse,
args.output_headers.unwrap_or(true),
)
.await?;
let reader_step: RecordBatchReaderSource = Box::new(VecRecordBatchReaderSource::new(batches));
let display_step = DisplayWriterStep {
output_format: args.output,
sparse: args.sparse,
headers: args.output_headers.unwrap_or(true),
};
display_step.execute(reader_step).await.map_err(Into::into)
.await
.map_err(Into::into)
}
2 changes: 1 addition & 1 deletion src/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl ReplPipelineBuilder {
let batches = self.batches.take().ok_or_else(|| {
Error::GenericError("select requires a preceding read in the pipe".to_string())
})?;
let selected = select::select_columns_to_batches(batches, columns).await?;
let selected = select::select_columns_to_batches(batches, columns)?;
self.batches = Some(selected);
Ok(())
}
Expand Down
1 change: 1 addition & 0 deletions src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ pub fn build_reader(
FileType::Csv => Box::new(ReadCsvStep {
path: path.to_string(),
has_header: csv_has_header,
limit,
}),
FileType::Orc => Box::new(ReadOrcStep {
args: ReadArgs {
Expand Down
18 changes: 17 additions & 1 deletion src/pipeline/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use crate::pipeline::batch_write::write_record_batches_with_sink;
/// Pipeline source step that reads a CSV file from disk into record batches.
pub struct ReadCsvStep {
/// Filesystem path of the CSV file to read.
pub path: String,
/// Whether the CSV file has a header row.
/// NOTE(review): presumably `None` defers to the reader's default — confirm against `build_reader` callers.
pub has_header: Option<bool>,
/// Maximum number of rows to read. None means read all.
pub limit: Option<usize>,
}

impl Source<dyn RecordBatchReader + 'static> for ReadCsvStep {
Expand All @@ -30,12 +32,26 @@ impl Source<dyn RecordBatchReader + 'static> for ReadCsvStep {
})
.map_err(|e| Error::GenericError(e.to_string()))?;

let batches = tokio::task::block_in_place(|| {
let mut batches = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(df.collect())
})
.map_err(|e| Error::GenericError(e.to_string()))?;

if let Some(limit) = self.limit {
let mut result = Vec::new();
let mut remaining = limit;
for batch in batches {
if remaining == 0 {
break;
}
let rows = batch.num_rows().min(remaining);
result.push(batch.slice(0, rows));
remaining -= rows;
}
batches = result;
}

Ok(Box::new(VecRecordBatchReader::new(batches)))
}
}
Expand Down
38 changes: 19 additions & 19 deletions src/pipeline/select.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! DataFusion DataFrame API for column selection.
//! Column selection: ColumnSpec resolution (shared by CLI and REPL), in-memory projection
//! (Arrow), and DataFusion-based read helpers.

use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -103,30 +104,29 @@ pub async fn read_avro_select(
Ok(Box::new(VecRecordBatchReaderSource::new(batches)))
}

/// Applies column selection to record batches using the DataFusion DataFrame API.
/// Returns the selected batches directly (for use when RecordBatchReaderSource is not needed).
/// Resolves ColumnSpec against the schema: Exact uses case-sensitive match, CaseInsensitive uses
/// case-insensitive match.
pub async fn select_columns_to_batches(
/// Applies column selection to record batches using the same resolution and projection
/// as the streaming SelectColumnsStep: resolve_column_specs then Arrow project by indices.
pub fn select_columns_to_batches(
batches: Vec<RecordBatch>,
specs: &[ColumnSpec],
) -> crate::Result<Vec<RecordBatch>> {
if batches.is_empty() {
if batches.is_empty() || specs.is_empty() {
return Ok(batches);
}
let schema = batches[0].schema();
let columns = resolve_column_specs(&schema, specs)?;
let ctx = SessionContext::new();
let col_refs: Vec<&str> = columns.iter().map(String::as_str).collect();
let df = ctx
.read_batches(batches)
.map_err(|e| crate::Error::GenericError(e.to_string()))?;
let df = df
.select_columns(&col_refs)
.map_err(|e| crate::Error::GenericError(e.to_string()))?;
df.collect()
.await
.map_err(|e| crate::Error::GenericError(e.to_string()))
let column_names = resolve_column_specs(schema.as_ref(), specs)?;
let indices: Vec<usize> = column_names
.iter()
.map(|col| {
schema
.index_of(col)
.map_err(|e| crate::Error::GenericError(format!("Column '{col}' not found: {e}")))
})
.collect::<crate::Result<Vec<_>>>()?;
batches
.into_iter()
.map(|batch| batch.project(&indices).map_err(crate::Error::from))
.collect()
}

/// Applies column selection to record batches using the DataFusion DataFrame API.
Expand Down
Loading