Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 32 additions & 43 deletions webgraph/src/utils/par_sort_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Mutex;

use anyhow::{ensure, Context, Result};
use dsi_bitstream::traits::NE;
Expand Down Expand Up @@ -281,30 +280,27 @@ impl<L> ParSortGraph<L> {
let presort_tmp_dir =
tempfile::tempdir().context("Could not create temporary directory")?;

let unsorted_pairs = unsorted_pairs.into_iter();
let num_blocks = unsorted_pairs.len();
let presort_tmp_dir = &presort_tmp_dir;

let partitioned_presorted_pairs = Mutex::new(vec![Vec::new(); num_blocks]);

std::thread::scope(|s| {
let partitioned_presorted_pairs = &partitioned_presorted_pairs;
let presort_tmp_dir = &presort_tmp_dir;
for (block_id, pair) in unsorted_pairs.enumerate() {
let deserializer = deserializer.clone();
let mut pl = pl.clone();
s.spawn(move || {
let partitioned_presorted_pairs = unsorted_pairs
.into_iter()
.enumerate()
.par_bridge()
.map_init(
|| (deserializer.clone(), pl.clone()),
|(deserializer, pl), (block_id, pair)| {
let mut unsorted_buffers = (0..num_partitions)
.map(|_| Vec::with_capacity(batch_size))
.collect::<Vec<_>>();
let mut sorted_pairs =
(0..num_partitions).map(|_| Vec::new()).collect::<Vec<_>>();

for (src, dst, label) in pair {
/* ensure!(
ensure!(
src < self.num_nodes,
"Expected {}, but got {src}",
self.num_nodes
); */
);
let partition_id = src / num_nodes_per_partition;

let sorted_pairs = &mut sorted_pairs[partition_id];
Expand All @@ -320,8 +316,7 @@ impl<L> ParSortGraph<L> {
sorted_pairs,
buf,
)
.context("Could not flush buffer")
.unwrap();
.context("Could not flush buffer")?;
assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
pl.update_with_count(buf_len);
}
Expand All @@ -347,43 +342,37 @@ impl<L> ParSortGraph<L> {
&mut pairs,
&mut buf,
)
.context("Could not flush buffer at the end")
.unwrap();
.context("Could not flush buffer at the end")?;
assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
pl.update_with_count(buf_len);
}

// TODO: ugly
partitioned_presorted_pairs.lock().unwrap()[block_id] = sorted_pairs;
});
}
});
Ok(sorted_pairs)
},
)
.collect::<Result<Vec<_>>>()?;

// At this point, the iterator could be collected into
// {worker_id -> {partition_id -> [iterators]}}
// i.e., Vec<Vec<Vec<BatchIterator>>>.
//
// Let's merge the {partition_id -> [iterators]} maps of each worker
let partitioned_presorted_pairs = partitioned_presorted_pairs
.into_inner()
.unwrap()
.into_par_iter()
.reduce(
|| (0..num_partitions).map(|_| Vec::new()).collect(),
|mut pair_partitions1: Vec<Vec<BatchIterator<D>>>,
pair_partitions2: Vec<Vec<BatchIterator<D>>>|
-> Vec<Vec<BatchIterator<D>>> {
assert_eq!(pair_partitions1.len(), num_partitions);
assert_eq!(pair_partitions2.len(), num_partitions);
for (partition1, partition2) in pair_partitions1
.iter_mut()
.zip(pair_partitions2.into_iter())
{
partition1.extend(partition2.into_iter());
}
pair_partitions1
},
);
let partitioned_presorted_pairs = partitioned_presorted_pairs.into_par_iter().reduce(
|| (0..num_partitions).map(|_| Vec::new()).collect(),
|mut pair_partitions1: Vec<Vec<BatchIterator<D>>>,
pair_partitions2: Vec<Vec<BatchIterator<D>>>|
-> Vec<Vec<BatchIterator<D>>> {
assert_eq!(pair_partitions1.len(), num_partitions);
assert_eq!(pair_partitions2.len(), num_partitions);
for (partition1, partition2) in pair_partitions1
.iter_mut()
.zip(pair_partitions2.into_iter())
{
partition1.extend(partition2.into_iter());
}
pair_partitions1
},
);
// At this point, the iterator was turned into
// {partition_id -> [iterators]}
// i.e., Vec<Vec<BatchIterator>>.
Expand Down
Loading