Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/target
/test/test_output.mtx
121 changes: 85 additions & 36 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use core::panic;
use std::env;
use std::collections::HashSet;
use std::io::{BufWriter, Write};
Expand All @@ -9,19 +8,52 @@ use rayon::prelude::*; // For parallel processing

/// The main function parses command-line arguments, processes the input CSV file,
/// optionally uses a zones CSV file, and writes the output in MTX format.
fn main() {
fn main() -> std::io::Result<()> {
let arg: Vec<String> = env::args().collect();

if arg.len() < 3 {
println!("Usage: csv_to_mtx <input.csv> <output.mtx/.mtx.gz> [zones.csv]");
return;
return Ok(());
}

let data = read_csv(&arg[1]);
let all_zones = get_all_zones(&arg, &data);
let zones_file = if arg.len() > 3 {
Some(&arg[3] as &str)
} else {
None
};
convert_csv_to_mtx(&arg[1], &arg[2], zones_file)
}

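/// Converts the input CSV file to MTX format and writes it to the output file.
///
/// Zones are read from `zones_file` when one is given; otherwise they are derived
/// from the origins and destinations found in the input data. Any error from
/// reading the inputs or writing the output is reported on stderr and returned.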
fn convert_csv_to_mtx(
input_file: &str,
output_file: &str,
zones_file: Option<&str>,
) -> std::io::Result<()> {
let data = match read_csv(input_file) {
Ok(data) => data,
Err(e) => {
eprintln!("Error reading CSV file: {}", e);
return Err(e);
}
};
let all_zones = match get_all_zones(zones_file, &data) {
Ok(zones) => zones,
Err(e) => {
eprintln!("Error reading zones file: {}", e);
return Err(e);
}
};
println!("Found {} zones", all_zones.len());
let matrix = build_matrix(&data, &all_zones);
write_mtx_file(&arg[2], &all_zones, &matrix);
match write_mtx_file(output_file, &all_zones, &matrix) {
Err(e) => {
eprintln!("Error writing MTX file: {}", e);
Err(e)
}
_ => Ok(()),
}
}

/// Reads the input CSV file and extracts the data as a vector of tuples containing
Expand All @@ -34,13 +66,9 @@ fn main() {
///
/// # Returns
/// A vector of tuples `(i32, i32, f32)` representing the origin, destination, and value.
/// An `Err` is returned if the input file cannot be opened.
fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> {
let file = match File::open(input_file) {
Ok(f) => f,
Err(e) => {
panic!("Error opening file {input_file}: {e}");
}
};
fn read_csv(input_file: &str) -> std::io::Result<Vec<(i32, i32, f32)>> {
let file = File::open(input_file)?;

let mut rdr = csv::ReaderBuilder::new()
.has_headers(false)
.from_reader(file);
Expand Down Expand Up @@ -73,13 +101,13 @@ fn read_csv(input_file: &str) -> Vec<(i32, i32, f32)> {
}
}

data
Ok(data)
} else {
// Rectangular format - pass the first record and remaining iterator
read_rectangular_csv_from_records(first_record, records)
Ok(read_rectangular_csv_from_records(first_record, records))
}
} else {
Vec::new()
Ok(Vec::new())
}
}

Expand Down Expand Up @@ -130,29 +158,29 @@ fn read_rectangular_csv_from_records(
/// or by extracting unique origins and destinations from the input data.
///
/// # Arguments
/// * `arg` - The command-line arguments.
/// * `zones_file` - Optional path to the zones CSV file.
/// * `data` - The vector of tuples `(i32, i32, f32)` representing the input data.
///
/// # Returns
/// A sorted vector of zone numbers (unique when derived from the input data),
/// or an `Err` if the zones file cannot be opened.
fn get_all_zones(arg: &[String], data: &[(i32, i32, f32)]) -> Vec<i32> {
if arg.len() > 3 {
let zone_file = File::open(&arg[3]).unwrap();
fn get_all_zones(zones_file: Option<&str>, data: &[(i32, i32, f32)]) -> std::io::Result<Vec<i32>> {
if let Some(zone_file) = zones_file {
let zone_file = File::open(zone_file)?;
let mut zone_rdr = csv::Reader::from_reader(zone_file);
let mut zones: Vec<i32> = zone_rdr
.records()
.filter_map(|result| result.ok()?.get(0)?.parse().ok())
.collect();
zones.sort_unstable();
zones
Ok(zones)
} else {
let zones: HashSet<i32> = data
.par_iter()
.flat_map(|(origin, destination, _)| vec![*origin, *destination])
.collect();
let mut zones: Vec<i32> = zones.into_iter().collect();
zones.sort_unstable();
zones
Ok(zones)
}
}

Expand Down Expand Up @@ -194,8 +222,8 @@ fn build_matrix(data: &[(i32, i32, f32)], all_zones: &[i32]) -> Vec<f32> {
///
/// # Errors
/// Returns an `std::io::Error` if it fails to create or write to the output file.
fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) {
let output_file = File::create(output_file_name).unwrap();
fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) -> std::io::Result<()> {
let output_file = File::create(output_file_name)?;
let mut writer: Box<dyn Write> = if output_file_name.ends_with(".gz") {
Box::new(BufWriter::new(GzEncoder::new(output_file, Compression::default())))
} else {
Expand All @@ -204,41 +232,62 @@ fn write_mtx_file(output_file_name: &str, all_zones: &[i32], matrix: &[f32]) {

let zone_count = all_zones.len() as i32;

writer.write_all(&0xC4D4F1B2u32.to_le_bytes()).unwrap(); // Magic Number
writer.write_all(&1i32.to_le_bytes()).unwrap(); // Version Number
writer.write_all(&1i32.to_le_bytes()).unwrap(); // Type
writer.write_all(&2i32.to_le_bytes()).unwrap(); // Dimensions
writer.write_all(&zone_count.to_le_bytes()).unwrap(); // Index size for origin
writer.write_all(&zone_count.to_le_bytes()).unwrap(); // Index size for destination
writer.write_all(&0xC4D4F1B2u32.to_le_bytes())?; // Magic Number
writer.write_all(&1i32.to_le_bytes())?; // Version Number
writer.write_all(&1i32.to_le_bytes())?; // Type
writer.write_all(&2i32.to_le_bytes())?; // Dimensions
writer.write_all(&zone_count.to_le_bytes())?; // Index size for origin
writer.write_all(&zone_count.to_le_bytes())?; // Index size for destination

let is_little_endian = cfg!(target_endian = "little");

if is_little_endian {
// Write all origin zone numbers in a single call
let origin_zone_bytes: &[u8] = bytemuck::cast_slice(all_zones);
writer.write_all(origin_zone_bytes).unwrap(); // Zone Numbers for Origin
writer.write_all(origin_zone_bytes)?; // Zone Numbers for Origin

// Write all destination zone numbers in a single call
writer.write_all(origin_zone_bytes).unwrap(); // Zone Numbers for Destination
writer.write_all(origin_zone_bytes)?; // Zone Numbers for Destination

// Write all matrix values in a single call
let matrix_bytes: &[u8] = bytemuck::cast_slice(matrix);
writer.write_all(matrix_bytes).unwrap();
writer.write_all(matrix_bytes)?;

} else {
// Convert all_zones to little-endian
let origin_zone_bytes: Vec<u8> = all_zones
.par_iter()
.flat_map(|&zone| zone.to_le_bytes())
.collect();
writer.write_all(&origin_zone_bytes).unwrap(); // Zone Numbers for Origin
writer.write_all(&origin_zone_bytes).unwrap(); // Zone Numbers for Destination
writer.write_all(&origin_zone_bytes)?; // Zone Numbers for Origin
writer.write_all(&origin_zone_bytes)?; // Zone Numbers for Destination

// Convert matrix to little-endian
let matrix_bytes: Vec<u8> = matrix
.par_iter()
.flat_map(|&value| value.to_le_bytes())
.collect();
writer.write_all(&matrix_bytes).unwrap();
writer.write_all(&matrix_bytes)?;
}
Ok(())
}

// Converts test/test.csv to an MTX file and checks the result against a known-good output.
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the full conversion on the sample CSV and checks that the bytes
    /// written match the known-good MTX fixture.
    #[test]
    fn test_csv_to_mtx() -> std::io::Result<()> {
        // No zones file is supplied, so the zones are derived from the CSV data.
        convert_csv_to_mtx("test/test.csv", "test/test_output.mtx", None)?;

        // Read the freshly written file first, then the fixture to compare against.
        let actual = std::fs::read("test/test_output.mtx")?;
        let expected = std::fs::read("test/test_expected.mtx")?;
        assert_eq!(actual, expected);
        Ok(())
    }
}
8 changes: 8 additions & 0 deletions test/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Origin,Destination,Value
1,1,0.1
1,2,0.2
1,3,0.3
2,1,1
2,2,2
2,3,3
4,4,0.1
Binary file added test/test_expected.mtx
Binary file not shown.