From d105e782bf78368e473e9262c92d407aebe2a226 Mon Sep 17 00:00:00 2001 From: James Vaughan Date: Sun, 2 Nov 2025 16:54:25 -0500 Subject: [PATCH] Added error messages Removed unwraps and added better message. Added a unit test to make sure that we generate an expected .mtx file. --- .gitignore | 1 + src/main.rs | 121 +++++++++++++++++++++++++++++------------ test/test.csv | 8 +++ test/test_expected.mtx | Bin 0 -> 120 bytes 4 files changed, 94 insertions(+), 36 deletions(-) create mode 100644 test/test.csv create mode 100644 test/test_expected.mtx diff --git a/.gitignore b/.gitignore index ea8c4bf..e476569 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/test/test_output.mtx diff --git a/src/main.rs b/src/main.rs index 7483ef4..2f8124e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,3 @@ -use core::panic; use std::env; use std::collections::HashSet; use std::io::{BufWriter, Write}; @@ -9,19 +8,52 @@ use rayon::prelude::*; // For parallel processing /// The main function parses command-line arguments, processes the input CSV file, /// optionally uses a zones CSV file, and writes the output in MTX format. -fn main() { +fn main() -> std::io::Result<()> { let arg: Vec = env::args().collect(); if arg.len() < 3 { println!("Usage: csv_to_mtx [zones.csv]"); - return; + return Ok(()); } - let data = read_csv(&arg[1]); - let all_zones = get_all_zones(&arg, &data); + let zones_file = if arg.len() > 3 { + Some(&arg[3] as &str) + } else { + None + }; + convert_csv_to_mtx(&arg[1], &arg[2], zones_file) +} + +/// Converts the input CSV file to MTX format and writes it to the output file. +/// +fn convert_csv_to_mtx( + input_file: &str, + output_file: &str, + zones_file: Option<&str>, +) -> std::io::Result<()> { + let data = match read_csv(input_file) { + Ok(data) => data, + Err(e) => { + eprintln!("Error reading CSV file: {}", e); + return Err(e); + } + }; + let all_zones = match get_all_zones(zones_file, &data) { + Ok(zones) => zones, + Err(e) => { + eprintln!("Error reading zones file: {}", e); + return Err(e); + } + }; println!("Found {} zones", all_zones.len()); let matrix = build_matrix(&data, &all_zones); - write_mtx_file(&arg[2], &all_zones, &matrix); + match write_mtx_file(output_file, &all_zones, &matrix) { + Err(e) => { + eprintln!("Error writing MTX file: {}", e); + Err(e) + } + _ => Ok(()), + } } /// Reads the input CSV file and extracts the data as a vector of tuples containing @@ -34,13 +66,9 @@ fn main() { /// /// # Returns /// A vector of tuples `(i32, i32, f32)` representing the origin, destination, and value. -fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> { - let file = match File::open(input_file) { - Ok(f) => f, - Err(e) => { - panic!("Error opening file {input_file}: {e}"); - } - }; +fn read_csv(input_file: &str) -> std::io::Result> { + let file = File::open(input_file)?; + let mut rdr = csv::ReaderBuilder::new() .has_headers(false) .from_reader(file); @@ -73,13 +101,13 @@ fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> { } } - data + Ok(data) } else { // Rectangular format - pass the first record and remaining iterator - read_rectangular_csv_from_records(first_record, records) + Ok(read_rectangular_csv_from_records(first_record, records)) } } else { - Vec::new() + Ok(Vec::new()) } } @@ -130,21 +158,21 @@ fn read_rectangular_csv_from_records( /// or by extracting unique origins and destinations from the input data. /// /// # Arguments -/// * `arg` - The command-line arguments. +/// * `zones_file` - Optional path to the zones CSV file. /// * `data` - The vector of tuples `(i32, i32, f32)` representing the input data. /// /// # Returns /// A sorted vector of unique zone numbers. -fn get_all_zones(arg: &[String], data: &[(i32, i32, f32)]) -> Vec { - if arg.len() > 3 { - let zone_file = File::open(&arg[3]).unwrap(); +fn get_all_zones(zones_file: Option<&str>, data: &[(i32, i32, f32)]) -> std::io::Result> { + if let Some(zone_file) = zones_file { + let zone_file = File::open(zone_file)?; let mut zone_rdr = csv::Reader::from_reader(zone_file); let mut zones: Vec = zone_rdr .records() .filter_map(|result| result.ok()?.get(0)?.parse().ok()) .collect(); zones.sort_unstable(); - zones + Ok(zones) } else { let zones: HashSet = data .par_iter() @@ -152,7 +180,7 @@ fn get_all_zones(arg: &[String], data: &[(i32, i32, f32)]) -> Vec { .collect(); let mut zones: Vec = zones.into_iter().collect(); zones.sort_unstable(); - zones + Ok(zones) } } @@ -194,8 +222,8 @@ fn build_matrix(data: &[(i32, i32, f32)], all_zones: &[i32]) -> Vec { /// /// # Panics /// This function will panic if it fails to create or write to the output file. -fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) { - let output_file = File::create(output_file_name).unwrap(); +fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) -> std::io::Result<()> { + let output_file = File::create(output_file_name)?; let mut writer: Box = if output_file_name.ends_with(".gz") { Box::new(BufWriter::new(GzEncoder::new(output_file, Compression::default()))) } else { @@ -204,26 +232,26 @@ fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) { let zone_count = all_zones.len() as i32; - writer.write_all(&0xC4D4F1B2u32.to_le_bytes()).unwrap(); // Magic Number - writer.write_all(&1i32.to_le_bytes()).unwrap(); // Version Number - writer.write_all(&1i32.to_le_bytes()).unwrap(); // Type - writer.write_all(&2i32.to_le_bytes()).unwrap(); // Dimensions - writer.write_all(&zone_count.to_le_bytes()).unwrap(); // Index size for origin - writer.write_all(&zone_count.to_le_bytes()).unwrap(); // Index size for destination + writer.write_all(&0xC4D4F1B2u32.to_le_bytes())?; // Magic Number + writer.write_all(&1i32.to_le_bytes())?; // Version Number + writer.write_all(&1i32.to_le_bytes())?; // Type + writer.write_all(&2i32.to_le_bytes())?; // Dimensions + writer.write_all(&zone_count.to_le_bytes())?; // Index size for origin + writer.write_all(&zone_count.to_le_bytes())?; // Index size for destination let is_little_endian = cfg!(target_endian = "little"); if is_little_endian { // Write all origin zone numbers in a single call let origin_zone_bytes: &[u8] = bytemuck::cast_slice(all_zones); - writer.write_all(origin_zone_bytes).unwrap(); // Zone Numbers for Origin + writer.write_all(origin_zone_bytes)?; // Zone Numbers for Origin // Write all destination zone numbers in a single call - writer.write_all(origin_zone_bytes).unwrap(); // Zone Numbers for Destination + writer.write_all(origin_zone_bytes)?; // Zone Numbers for Destination // Write all matrix values in a single call let matrix_bytes: &[u8] = bytemuck::cast_slice(matrix); - writer.write_all(matrix_bytes).unwrap(); + writer.write_all(matrix_bytes)?; } else { // Convert all_zones to little-endian @@ -231,14 +259,35 @@ fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) { .par_iter() .flat_map(|&zone| zone.to_le_bytes()) .collect(); - writer.write_all(&origin_zone_bytes).unwrap(); // Zone Numbers for Origin - writer.write_all(&origin_zone_bytes).unwrap(); // Zone Numbers for Destination + writer.write_all(&origin_zone_bytes)?; // Zone Numbers for Origin + writer.write_all(&origin_zone_bytes)?; // Zone Numbers for Destination // Convert matrix to little-endian let matrix_bytes: Vec = matrix .par_iter() .flat_map(|&value| value.to_le_bytes()) .collect(); - writer.write_all(&matrix_bytes).unwrap(); + writer.write_all(&matrix_bytes)?; + } + Ok(()) +} + +// Write a test using test.csv to make sure that it converts to an mtx file +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_csv_to_mtx() -> std::io::Result<()> { + let input_file = "test/test.csv"; + let output_file = "test/test_output.mtx"; + + convert_csv_to_mtx(input_file, output_file, None)?; + + // Compare against a known good output file + let expected_output_file = "test/test_expected.mtx"; + let output_data = std::fs::read(output_file)?; + let expected_data = std::fs::read(expected_output_file)?; + assert_eq!(output_data, expected_data); + Ok(()) } } diff --git a/test/test.csv b/test/test.csv new file mode 100644 index 0000000..5565543 --- /dev/null +++ b/test/test.csv @@ -0,0 +1,8 @@ +Origin,Destination,Value +1,1,0.1 +1,2,0.2 +1,3,0.3 +2,1,1 +2,2,2 +2,3,3 +4,4,0.1 diff --git a/test/test_expected.mtx b/test/test_expected.mtx new file mode 100644 index 0000000000000000000000000000000000000000..b33b76fc24a99d68ca8e7832c8e919588647e977 GIT binary patch literal 120 zcmdn=@yZcK1_lNYW&&asD2DNwaq!QcIb(bFjE~){nKSKxN`Roj9!NL<5fI}CAiV%& CIu1+# literal 0 HcmV?d00001