diff --git a/webgraph/src/utils/par_sort_graph.rs b/webgraph/src/utils/par_sort_graph.rs index cf0b82f7..f12d87e0 100644 --- a/webgraph/src/utils/par_sort_graph.rs +++ b/webgraph/src/utils/par_sort_graph.rs @@ -12,7 +12,6 @@ use std::marker::PhantomData; use std::num::NonZeroUsize; use std::path::Path; -use std::sync::Mutex; use anyhow::{ensure, Context, Result}; use dsi_bitstream::traits::NE; @@ -281,18 +280,15 @@ impl ParSortGraph { let presort_tmp_dir = tempfile::tempdir().context("Could not create temporary directory")?; - let unsorted_pairs = unsorted_pairs.into_iter(); - let num_blocks = unsorted_pairs.len(); + let presort_tmp_dir = &presort_tmp_dir; - let partitioned_presorted_pairs = Mutex::new(vec![Vec::new(); num_blocks]); - - std::thread::scope(|s| { - let partitioned_presorted_pairs = &partitioned_presorted_pairs; - let presort_tmp_dir = &presort_tmp_dir; - for (block_id, pair) in unsorted_pairs.enumerate() { - let deserializer = deserializer.clone(); - let mut pl = pl.clone(); - s.spawn(move || { + let partitioned_presorted_pairs = unsorted_pairs + .into_iter() + .enumerate() + .par_bridge() + .map_init( + || (deserializer.clone(), pl.clone()), + |(deserializer, pl), (block_id, pair)| { let mut unsorted_buffers = (0..num_partitions) .map(|_| Vec::with_capacity(batch_size)) .collect::>(); @@ -300,11 +296,11 @@ impl ParSortGraph { (0..num_partitions).map(|_| Vec::new()).collect::>(); for (src, dst, label) in pair { - /* ensure!( + ensure!( src < self.num_nodes, "Expected {}, but got {src}", self.num_nodes - ); */ + ); let partition_id = src / num_nodes_per_partition; let sorted_pairs = &mut sorted_pairs[partition_id]; @@ -320,8 +316,7 @@ impl ParSortGraph { sorted_pairs, buf, ) - .context("Could not flush buffer") - .unwrap(); + .context("Could not flush buffer")?; assert!(buf.is_empty(), "flush_buffer did not empty the buffer"); pl.update_with_count(buf_len); } @@ -347,43 +342,37 @@ impl ParSortGraph { &mut pairs, &mut buf, ) - .context("Could not flush buffer at the end") - .unwrap(); + .context("Could not flush buffer at the end")?; assert!(buf.is_empty(), "flush_buffer did not empty the buffer"); pl.update_with_count(buf_len); } - // TODO: ugly - partitioned_presorted_pairs.lock().unwrap()[block_id] = sorted_pairs; - }); - } - }); + Ok(sorted_pairs) + }, + ) + .collect::>>()?; // At this point, the iterator could be collected into // {worker_id -> {partition_id -> [iterators]}} // ie. Vec>>>. // // Let's merge the {partition_id -> [iterators]} maps of each worker - let partitioned_presorted_pairs = partitioned_presorted_pairs - .into_inner() - .unwrap() - .into_par_iter() - .reduce( - || (0..num_partitions).map(|_| Vec::new()).collect(), - |mut pair_partitions1: Vec>>, - pair_partitions2: Vec>>| - -> Vec>> { - assert_eq!(pair_partitions1.len(), num_partitions); - assert_eq!(pair_partitions2.len(), num_partitions); - for (partition1, partition2) in pair_partitions1 - .iter_mut() - .zip(pair_partitions2.into_iter()) - { - partition1.extend(partition2.into_iter()); - } - pair_partitions1 - }, - ); + let partitioned_presorted_pairs = partitioned_presorted_pairs.into_par_iter().reduce( + || (0..num_partitions).map(|_| Vec::new()).collect(), + |mut pair_partitions1: Vec>>, + pair_partitions2: Vec>>| + -> Vec>> { + assert_eq!(pair_partitions1.len(), num_partitions); + assert_eq!(pair_partitions2.len(), num_partitions); + for (partition1, partition2) in pair_partitions1 + .iter_mut() + .zip(pair_partitions2.into_iter()) + { + partition1.extend(partition2.into_iter()); + } + pair_partitions1 + }, + ); // At this point, the iterator was turned into // {partition_id -> [iterators]} // ie. Vec>>.