Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ rand = "0.8.5"
tch = { version = "0.18.0", optional = true, features = ["download-libtorch"] }
rayon = { version = "1.7.0", optional = true }
once_cell = { version = "1.17.1", optional = true }
crossbeam-channel = "0.5.8"


[dev-dependencies]
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ More info in the [documentation](https://docs.rs/ai-dataloader/).
Examples can be found in the [examples](examples/) folder but here there is a simple one

```rust
use ai_dataloader::DataLoader;
let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build();
use ai_dataloader::indexable::DataLoader;
use std::sync::Arc;
let dataset = vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")];
let loader = DataLoader::builder(Arc::new(dataset)).batch_size(2).shuffle().build();

for (label, text) in &loader {
println!("Label {label:?}");
Expand Down
10 changes: 7 additions & 3 deletions benches/throughput.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use ndarray::Array3;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use std::sync::Arc;

const NUM_CLASS: usize = 20;
const IMAGE_SIZE: usize = 50;
const DATASET_LEN: usize = 500;

/// Dataset that return the same random image each time.
#[derive(Clone)]
pub struct RandomUnique {
image: Array3<u8>,
}
Expand Down Expand Up @@ -38,7 +40,7 @@ impl GetSample for RandomUnique {
}
}

fn iter_all_dataset(loader: &DataLoader<RandomUnique>) -> usize {
fn iter_all_dataset(loader: DataLoader<RandomUnique>) -> usize {
let mut num_sample = 0;
for (_sample, label) in loader.iter() {
num_sample += label.len();
Expand All @@ -47,15 +49,17 @@ fn iter_all_dataset(loader: &DataLoader<RandomUnique>) -> usize {
}

fn bench(c: &mut Criterion) {
let loader = DataLoader::builder(RandomUnique::default())
let loader = DataLoader::builder(Arc::new(RandomUnique::default()))
.batch_size(16)
.build();

const BYTES: u64 = DATASET_LEN as u64 * IMAGE_SIZE as u64 * IMAGE_SIZE as u64 * 3;

let mut group = c.benchmark_group("throughput-example");
group.throughput(Throughput::Bytes(BYTES));
group.bench_function("iter_all_dataset", |b| b.iter(|| iter_all_dataset(&loader)));
group.bench_function("iter_all_dataset", |b| {
b.iter(|| iter_all_dataset(loader.clone()))
});
group.finish();
}

Expand Down
Loading
Loading