diff --git a/src/db/query.rs b/src/db/query.rs index 7ba01f9..9827b31 100644 --- a/src/db/query.rs +++ b/src/db/query.rs @@ -70,6 +70,81 @@ impl CellValue { pub fn display_width(&self) -> usize { unicode_width::UnicodeWidthStr::width(self.display().as_str()) } + + /// Compare two CellValues for sorting purposes. + /// Returns an Ordering suitable for sort operations. + /// NULLs are always sorted last regardless of direction. + pub fn sort_cmp(&self, other: &CellValue) -> std::cmp::Ordering { + use std::cmp::Ordering; + match (self, other) { + // NULLs always last + (CellValue::Null, CellValue::Null) => Ordering::Equal, + (CellValue::Null, _) => Ordering::Greater, + (_, CellValue::Null) => Ordering::Less, + + // Booleans: false < true + (CellValue::Bool(a), CellValue::Bool(b)) => a.cmp(b), + + // Integers + (CellValue::Int16(a), CellValue::Int16(b)) => a.cmp(b), + (CellValue::Int32(a), CellValue::Int32(b)) => a.cmp(b), + (CellValue::Int64(a), CellValue::Int64(b)) => a.cmp(b), + + // Cross-integer comparison: promote to i64 + (CellValue::Int16(a), CellValue::Int32(b)) => (*a as i64).cmp(&(*b as i64)), + (CellValue::Int32(a), CellValue::Int16(b)) => (*a as i64).cmp(&(*b as i64)), + (CellValue::Int16(a), CellValue::Int64(b)) => (*a as i64).cmp(b), + (CellValue::Int64(a), CellValue::Int16(b)) => a.cmp(&(*b as i64)), + (CellValue::Int32(a), CellValue::Int64(b)) => (*a as i64).cmp(b), + (CellValue::Int64(a), CellValue::Int32(b)) => a.cmp(&(*b as i64)), + + // Floats + (CellValue::Float32(a), CellValue::Float32(b)) => { + a.partial_cmp(b).unwrap_or(Ordering::Equal) + } + (CellValue::Float64(a), CellValue::Float64(b)) => { + a.partial_cmp(b).unwrap_or(Ordering::Equal) + } + (CellValue::Float32(a), CellValue::Float64(b)) => { + (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal) + } + (CellValue::Float64(a), CellValue::Float32(b)) => { + a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal) + } + + // Numeric vs float: promote int to f64 + (CellValue::Int16(a), CellValue::Float64(b)) => { + (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal) + } + (CellValue::Float64(a), CellValue::Int16(b)) => { + a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal) + } + (CellValue::Int32(a), CellValue::Float64(b)) => { + (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal) + } + (CellValue::Float64(a), CellValue::Int32(b)) => { + a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal) + } + (CellValue::Int64(a), CellValue::Float64(b)) => { + (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal) + } + (CellValue::Float64(a), CellValue::Int64(b)) => { + a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal) + } + + // Text + (CellValue::Text(a), CellValue::Text(b)) => a.cmp(b), + + // Dates and times + (CellValue::Date(a), CellValue::Date(b)) => a.cmp(b), + (CellValue::Time(a), CellValue::Time(b)) => a.cmp(b), + (CellValue::DateTime(a), CellValue::DateTime(b)) => a.cmp(b), + (CellValue::TimestampTz(a), CellValue::TimestampTz(b)) => a.cmp(b), + + // Fallback: compare display strings + (a, b) => a.display().cmp(&b.display()), + } + } } #[allow(dead_code)] @@ -292,6 +367,105 @@ mod tests { assert_eq!(r.error.unwrap(), "bad query"); assert!(r.rows.is_empty()); } + + // --- CellValue sort_cmp --- + + #[test] + fn test_sort_cmp_nulls_last() { + use std::cmp::Ordering; + assert_eq!(CellValue::Null.sort_cmp(&CellValue::Null), Ordering::Equal); + assert_eq!( + CellValue::Null.sort_cmp(&CellValue::Int32(1)), + Ordering::Greater + ); + assert_eq!( + CellValue::Int32(1).sort_cmp(&CellValue::Null), + Ordering::Less + ); + } + + #[test] + fn test_sort_cmp_integers() { + use std::cmp::Ordering; + assert_eq!( + CellValue::Int32(1).sort_cmp(&CellValue::Int32(2)), + Ordering::Less + ); + assert_eq!( + CellValue::Int32(5).sort_cmp(&CellValue::Int32(5)), + Ordering::Equal + ); + assert_eq!( + CellValue::Int64(100).sort_cmp(&CellValue::Int64(50)), + Ordering::Greater + ); + } + + #[test] + fn test_sort_cmp_cross_integer() { + use std::cmp::Ordering; + assert_eq!( + CellValue::Int16(10).sort_cmp(&CellValue::Int32(20)), + Ordering::Less + ); + assert_eq!( + CellValue::Int32(30).sort_cmp(&CellValue::Int64(30)), + Ordering::Equal + ); + } + + #[test] + fn test_sort_cmp_floats() { + use std::cmp::Ordering; + assert_eq!( + CellValue::Float64(1.5).sort_cmp(&CellValue::Float64(2.5)), + Ordering::Less + ); + assert_eq!( + CellValue::Float32(3.0).sort_cmp(&CellValue::Float64(3.0)), + Ordering::Equal + ); + } + + #[test] + fn test_sort_cmp_text() { + use std::cmp::Ordering; + assert_eq!( + CellValue::Text("apple".into()).sort_cmp(&CellValue::Text("banana".into())), + Ordering::Less + ); + assert_eq!( + CellValue::Text("zebra".into()).sort_cmp(&CellValue::Text("aardvark".into())), + Ordering::Greater + ); + } + + #[test] + fn test_sort_cmp_booleans() { + use std::cmp::Ordering; + assert_eq!( + CellValue::Bool(false).sort_cmp(&CellValue::Bool(true)), + Ordering::Less + ); + } + + #[test] + fn test_sort_stable_with_nulls() { + let mut values = vec![ + CellValue::Int32(3), + CellValue::Null, + CellValue::Int32(1), + CellValue::Null, + CellValue::Int32(2), + ]; + values.sort_by(|a, b| a.sort_cmp(b)); + // Nulls should be at the end + assert_eq!(values[0].display(), "1"); + assert_eq!(values[1].display(), "2"); + assert_eq!(values[2].display(), "3"); + assert_eq!(values[3].display(), "NULL"); + assert_eq!(values[4].display(), "NULL"); + } } fn extract_value(row: &Row, idx: usize, pg_type: &Type) -> CellValue { diff --git a/src/ui/app.rs b/src/ui/app.rs index 72dd4ea..ede0429 100644 --- a/src/ui/app.rs +++ b/src/ui/app.rs @@ -127,6 +127,9 @@ pub struct App { pub result_scroll_y: usize, pub result_selected_row: usize, pub result_selected_col: usize, + pub result_sort_column: Option, + pub result_sort_ascending: bool, + pub result_sort_indices: Vec, // Toasts pub toasts: Vec, @@ -298,6 +301,9 @@ impl App { result_scroll_y: 0, result_selected_row: 0, result_selected_col: 0, + result_sort_column: None, + result_sort_ascending: true, + result_sort_indices: Vec::new(), toasts: Vec::new(), is_loading: false, @@ -918,12 +924,16 @@ impl App { self.focus = Focus::ExportPicker; } } + KeyCode::Char('s') if !key.modifiers.contains(KeyModifiers::CONTROL) => { + self.toggle_column_sort(); + } KeyCode::Char('[') if key.modifiers.contains(KeyModifiers::CONTROL) => { if self.current_result > 0 { self.current_result -= 1; self.result_selected_row = 0; self.result_selected_col = 0; self.result_scroll_y = 0; + self.clear_sort(); } } KeyCode::Char(']') if key.modifiers.contains(KeyModifiers::CONTROL) => { @@ -932,6 +942,7 @@ impl App { self.result_selected_row = 0; self.result_selected_col = 0; self.result_scroll_y = 0; + self.clear_sort(); } } _ => {} @@ -1503,6 +1514,7 @@ impl App { self.current_result = self.results.len() - 1; self.result_selected_row = 0; self.result_selected_col = 0; + self.clear_sort(); } else { self.set_status("Not connected to database".to_string(), StatusType::Error); } @@ -1510,6 +1522,77 @@ impl App { Ok(()) } + fn toggle_column_sort(&mut self) { + let col = self.result_selected_col; + if let Some(result) = self.results.get(self.current_result) { + if col >= result.columns.len() { + return; + } + + if self.result_sort_column == Some(col) { + if self.result_sort_ascending { + // Was ascending, switch to descending + self.result_sort_ascending = false; + } else { + // Was descending, clear sort + self.result_sort_column = None; + self.result_sort_ascending = true; + self.result_sort_indices.clear(); + return; + } + } else { + // New column, start ascending + self.result_sort_column = Some(col); + self.result_sort_ascending = true; + } + + self.rebuild_sort_indices(); + } + } + + fn rebuild_sort_indices(&mut self) { + let col = match self.result_sort_column { + Some(c) => c, + None => return, + }; + let result = match self.results.get(self.current_result) { + Some(r) => r, + None => return, + }; + + let ascending = self.result_sort_ascending; + let mut indices: Vec = (0..result.rows.len()).collect(); + indices.sort_by(|&a, &b| { + let val_a = &result.rows[a][col]; + let val_b = &result.rows[b][col]; + let cmp = val_a.sort_cmp(val_b); + if ascending { + cmp + } else { + cmp.reverse() + } + }); + self.result_sort_indices = indices; + } + + fn clear_sort(&mut self) { + self.result_sort_column = None; + self.result_sort_ascending = true; + self.result_sort_indices.clear(); + } + + /// Map a display row index to the actual row index, accounting for sort order. + pub fn sorted_row_index(&self, display_idx: usize) -> usize { + if self.result_sort_indices.is_empty() { + display_idx + } else { + self.result_sort_indices + .get(display_idx) + .copied() + .unwrap_or(display_idx) + } + } + fn copy_selected_cell(&mut self) { if let Some(result) = self.results.get(self.current_result) { if let Some(row) = result.rows.get(self.result_selected_row) { diff --git a/src/ui/components.rs b/src/ui/components.rs index fb06b8a..326490f 100644 --- a/src/ui/components.rs +++ b/src/ui/components.rs @@ -693,12 +693,21 @@ fn draw_result_table(frame: &mut Frame, app: &App, result: &crate::db::QueryResu }) .collect(); - // Create header + // Create header with sort indicators let header_cells: Vec = result .columns .iter() .enumerate() .map(|(i, col)| { + let sort_indicator = if app.result_sort_column == Some(i) { + if app.result_sort_ascending { + " \u{25B2}" // ▲ + } else { + " \u{25BC}" // ▼ + } + } else { + "" + }; let style = if i == app.result_selected_col { Style::default() .fg(theme.text_accent) @@ -708,7 +717,7 @@ fn draw_result_table(frame: &mut Frame, app: &App, result: &crate::db::QueryResu .fg(theme.text_primary) .add_modifier(Modifier::BOLD) }; - Cell::from(col.name.clone()).style(style) + Cell::from(format!("{}{}", col.name, sort_indicator)).style(style) }) .collect(); @@ -716,17 +725,20 @@ fn draw_result_table(frame: &mut Frame, app: &App, result: &crate::db::QueryResu .style(Style::default().bg(theme.bg_secondary)) .height(1); - // Create rows + // Create rows using sorted indices when active let visible_height = area.height.saturating_sub(2) as usize; let start_row = app.result_scroll_y; + let has_sort = !app.result_sort_indices.is_empty(); + + let rows: Vec = (start_row..result.rows.len().min(start_row + visible_height)) + .map(|display_idx| { + let actual_idx = if has_sort { + app.sorted_row_index(display_idx) + } else { + display_idx + }; + let row = &result.rows[actual_idx]; - let rows: Vec = result - .rows - .iter() - .enumerate() - .skip(start_row) - .take(visible_height) - .map(|(row_idx, row)| { let cells: Vec = row .iter() .enumerate() @@ -734,7 +746,7 @@ fn draw_result_table(frame: &mut Frame, app: &App, result: &crate::db::QueryResu let display = cell.display(); let truncated: String = display.chars().take(40).collect(); - let style = if row_idx == app.result_selected_row { + let style = if display_idx == app.result_selected_row { if col_idx == app.result_selected_col { Style::default() .bg(theme.bg_highlight) @@ -1317,6 +1329,7 @@ fn draw_help_overlay(frame: &mut Frame, app: &App) { " Arrow keys Navigate cells", " Esc Back to editor", " Ctrl+C Copy cell value", + " s Sort by column", " Ctrl+S Export results", " Ctrl+[/] Prev/Next result set", " PageUp/Down Scroll results",