diff --git a/.cargo/config.toml b/.cargo/config.toml
index 6330d419..1ecf5a7b 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,5 +1,5 @@
# Needed for WASM unstable features
[build]
-rustflags = [ "--cfg=web_sys_unstable_apis" ]
-rustdocflags = [ "--cfg=web_sys_unstable_apis" ]
-#target = "wasm32-unknown-unknown"
+rustflags = ["--cfg=web_sys_unstable_apis"]
+rustdocflags = ["--cfg=web_sys_unstable_apis"]
+target = "wasm32-unknown-unknown"
diff --git a/.gitignore b/.gitignore
index 3ffe4ce7..ed4b9df2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,9 @@ venv/
# proptest regression tests
proptest-regressions/
+
+# samply profile
+profile.json
+
+# webdriver configs
+webdriver.json
diff --git a/Cargo.toml b/Cargo.toml
index eae21a07..d69df143 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -2,12 +2,11 @@
members = [
"crates/ratchet-hub",
"crates/ratchet-core",
- "crates/ratchet-web",
+ "crates/ratchet-web-train",
"crates/ratchet-loader",
"crates/ratchet-models",
"crates/ratchet-nn",
"crates/ratchet-hub",
- "crates/ratchet-cli",
"crates/ratchet-macros",
"crates/ratchet-datasets",
]
@@ -28,7 +27,7 @@ inherits = "release"
debug = 2
[workspace.dependencies]
-wgpu = { git = "https://github.com/vinhowe/wgpu", branch = "feature/multi-dim-compute-subgroups", features = [
+wgpu = { git = "https://github.com/vinhowe/wgpu", branch = "feature/extract-webgpu-gpubuffer", features = [
"fragile-send-sync-non-atomic-wasm",
] }
bytemuck = { version = "1.14.0", features = [
@@ -49,7 +48,7 @@ anyhow = "1.0.79"
tokenizers = "0.19.1"
js-sys = "0.3.64"
-wasm-bindgen = "0.2.91"
+wasm-bindgen = "0.2.100"
wasm-bindgen-test = "0.3.34"
cfg-if = "1.0.0"
chrono = "0.4.35"
diff --git a/README.md b/README.md
index 11980782..5c94cccf 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,8 @@
-
- For the time being, I don't have instructions on how to run this. You're on your own for a tiny bit :)
-
+# toy transformer
-
-

-
-(backward)
-
-
-
+Watch a toy model train, in your browser, with WebGPU.
+
+## Attribution
This is a fork of [Ratchet](https://github.com/huggingface/ratchet), hacked and butchered to add backpropogation, to show that it is technically possible to train language models in a (WebGPU-enabled) browser.
diff --git a/crates/ratchet-cli/Cargo.toml b/crates/ratchet-cli/Cargo.toml
deleted file mode 100644
index 014bf678..00000000
--- a/crates/ratchet-cli/Cargo.toml
+++ /dev/null
@@ -1,26 +0,0 @@
-[package]
-name = "ratchet-cli"
-version = "0.1.0"
-edition = "2021"
-
-[[bin]]
-name = "ratchet"
-path = "src/bin/cli.rs"
-
-[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
-ratchet = { path = "../ratchet-core" }
-ratchet-loader = { path = "../ratchet-loader" }
-ratchet-models = { path = "../ratchet-models" }
-ratchet-hub = { path = "../ratchet-hub" }
-ratchet-nn = { path = "../ratchet-nn" }
-log.workspace = true
-clap = { workspace = true, features = ["derive"] }
-hf-hub = { workspace = true }
-serde_json = { workspace = true }
-env_logger = { workspace = true }
-fern = { workspace = true }
-chrono = { workspace = true }
-tokenizers = { workspace = true }
-ndarray = { workspace = true }
-ndarray-stats = { workspace = true }
-anyhow.workspace = true
diff --git a/crates/ratchet-cli/src/bin/cli.rs b/crates/ratchet-cli/src/bin/cli.rs
deleted file mode 100644
index 08eb3e3b..00000000
--- a/crates/ratchet-cli/src/bin/cli.rs
+++ /dev/null
@@ -1,253 +0,0 @@
-#[cfg(not(target_arch = "wasm32"))]
-mod cli {
- use clap::{value_parser, Arg, ArgMatches, Command};
- use hf_hub::api::sync::Api;
- use ndarray::Axis;
- use ndarray_stats::QuantileExt;
- use ratchet::{shape, Device, DeviceRequest, Tensor};
- use ratchet_loader::gguf::gguf::{self, Header};
- use ratchet_models::registry::{
- AvailableModels, Quantization, WhisperVariants as RegistryWhisper,
- };
- use ratchet_models::whisper::options::DecodingOptionsBuilder;
- use ratchet_models::whisper::transcribe::transcribe;
- use ratchet_models::{phi2::Phi2, whisper::Whisper};
- use ratchet_nn::Module;
- use std::io::Write;
- use std::path::Path;
- use std::process::Command as TermCommand;
- use tokenizers::Tokenizer;
-
- fn ffmpeg_preproc>(path: P) -> Vec {
- let path = path.as_ref();
- let output = TermCommand::new("ffmpeg")
- .args([
- "-nostdin",
- "-threads",
- "0",
- "-i",
- path.to_str().unwrap(),
- "-f",
- "s16le",
- "-ac",
- "1",
- "-acodec",
- "pcm_s16le",
- "-loglevel",
- "error",
- "-ar",
- "16000",
- "-",
- ])
- .output()
- .expect("Failed to execute ffmpeg command");
-
- if !output.status.success() {
- panic!(
- "ffmpeg command failed: {}",
- String::from_utf8_lossy(&output.stderr)
- );
- }
-
- let audio_data = output.stdout;
- let mut samples = Vec::new();
-
- for chunk in audio_data.chunks(2) {
- let sample = i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0;
- samples.push(sample);
- }
-
- samples
- }
-
- pub fn start_logger() {
- let logger = fern::Dispatch::new()
- .format(|out, message, record| {
- out.finish(format_args!(
- "{}[{}][{}] {}",
- chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
- record.target(),
- record.level(),
- message
- ))
- })
- .level_for("tokenizers", log::LevelFilter::Off)
- .level(log::LevelFilter::Warn)
- .apply();
- match logger {
- Ok(_) => log::info!("Logging initialized."),
- Err(error) => eprintln!("Error initializing logging: {:?}", error),
- }
- }
-
- fn handle_whisper(matches: &ArgMatches, api: Api) {
- let quantization = matches
- .get_one::("quantization")
- .unwrap_or(&Quantization::Q8_0);
-
- let mut whisper = if let Some(variant) = matches.get_one::("variant") {
- let model = AvailableModels::Whisper(variant.clone());
- let repo = api.model(model.repo_id());
- let model_path = repo.get(&model.model_id(quantization.clone())).unwrap();
- println!("MODEL PATH: {}", model_path.display());
-
- let mut reader = std::io::BufReader::new(std::fs::File::open(model_path).unwrap());
- let device = Device::request_device(DeviceRequest::GPU).unwrap();
- let header = gguf::Header::read(&mut reader).unwrap();
- Whisper::load(header, variant.clone(), &mut reader, device).unwrap()
- } else {
- panic!("Model not found");
- };
-
- if let Some(input) = matches.get_one::("input") {
- let options = DecodingOptionsBuilder::new().build();
- let samples = ffmpeg_preproc(input);
- let transcript =
- transcribe(&mut whisper, samples, options, Some(|s| println!("{}", s))).unwrap();
- log::info!("Processing time: {:?}", transcript.processing_time);
- } else {
- panic!("Input file not found");
- };
- }
-
- fn handle_phi2(matches: &ArgMatches, api: Api) -> anyhow::Result<()> {
- let _ = env_logger::builder().is_test(true).try_init();
- let model_repo = api.model("FL33TW00D-HF/phi2".to_string());
- let model_path = model_repo.get("phi2-q8_0.gguf").unwrap();
- println!("MODEL PATH: {}", model_path.display());
- let mut reader = std::io::BufReader::new(std::fs::File::open(model_path)?);
- let device = Device::request_device(DeviceRequest::GPU)?;
- let content = Header::read(&mut reader)?;
- let mut model = Phi2::load(content, &mut reader, &device)?;
-
- let tokenizer =
- Tokenizer::from_file(concat!("../../", "/models/microsoft/phi-2/tokenizer.json"))
- .unwrap();
-
- let prompt = if let Some(prompt) = matches.get_one::("prompt") {
- prompt
- } else {
- "def print_prime(n):"
- };
-
- let max_tokens = matches.get_one::("max-tokens").unwrap();
-
- let encoding = tokenizer.encode(prompt, true).unwrap();
- let mut tokens = encoding
- .get_ids()
- .iter()
- .map(|&x| x as i32)
- .collect::>();
-
- print!("{}", prompt);
- std::io::stdout().flush().unwrap();
- let mut all_tokens = tokens.clone();
- let mut loop_cnt = 0;
- let start_time = std::time::Instant::now();
- while tokens[tokens.len() - 1] != 50256 && loop_cnt < *max_tokens {
- let input = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
- let result = model.schedule(input)?.full()?.resolve()?;
- let logits = result.to(&Device::CPU)?;
- model.cache_mut().update(tokens.len());
-
- tokens = logits
- .to_ndarray_view::()
- .map_axis(Axis(2), |row| row.argmax_skipnan().unwrap())
- .iter()
- .map(|&x| x as i32)
- .collect::>();
- let u32_toks = tokens.iter().map(|&x| x as u32).collect::>();
- print!("{}", tokenizer.decode(&u32_toks, true).unwrap());
- std::io::stdout().flush().unwrap();
- all_tokens.extend(tokens.clone());
- loop_cnt += 1;
- }
- let elapsed = start_time.elapsed();
- println!("\nElapsed time: {:?}", elapsed);
- println!(
- "tok/sec: {}",
- all_tokens.len() as f64 / elapsed.as_secs_f64()
- );
- Ok(())
- }
-
- #[cfg(not(target_arch = "wasm32"))]
- fn main() -> Result<(), Box> {
- env_logger::init();
- let matches = Command::new("ratchet")
- .about("LLM & VLM CLI")
- .version("0.1.0")
- .subcommand_required(true)
- .arg_required_else_help(true)
- .subcommand(
- Command::new("whisper")
- .long_about(
- "Cross-platform, GPU accelerated implementation of OpenAI's Whisper Model.",
- )
- .arg(
- Arg::new("variant")
- .short('v')
- .long("variant")
- .default_value("small")
- .help("Whisper model variant to use.")
- .value_parser(value_parser!(RegistryWhisper)),
- )
- .arg(
- Arg::new("quantization")
- .short('q')
- .long("quantization")
- .default_value("f32")
- .help("Model quantization to use.")
- .value_parser(value_parser!(Quantization)),
- )
- .arg(
- Arg::new("input")
- .short('i')
- .long("input")
- .required(true)
- .help("Path to the input file"),
- ),
- )
- .subcommand(
- Command::new("phi2")
- .long_about(
- "Cross-platform, GPU accelerated implementation of Microsoft's Phi2 model.",
- )
- .arg(
- Arg::new("prompt")
- .short('p')
- .long("prompt")
- .required(true)
- .help("Input prompt."),
- )
- .arg(
- Arg::new("max-tokens")
- .short('m')
- .long("max-tokens")
- .default_value("256")
- .value_parser(value_parser!(usize))
- .help("Maximum number of tokens to generate."),
- ),
- )
- .get_matches();
-
- let api = Api::new().unwrap();
- if let Some(matches) = matches.subcommand_matches("phi2") {
- let _ = handle_phi2(matches, api);
- } else if let Some(matches) = matches.subcommand_matches("whisper") {
- handle_whisper(matches, api);
- }
-
- Ok(())
- }
-}
-
-#[cfg(not(target_arch = "wasm32"))]
-fn main() {
- cli::main();
-}
-
-#[cfg(target_arch = "wasm32")]
-fn main() {
- // Empty function to get the CLI to compile anyway
-}
diff --git a/crates/ratchet-cli/src/lib.rs b/crates/ratchet-cli/src/lib.rs
deleted file mode 100644
index 8b137891..00000000
--- a/crates/ratchet-cli/src/lib.rs
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml
index ef2e9f68..7433c476 100644
--- a/crates/ratchet-core/Cargo.toml
+++ b/crates/ratchet-core/Cargo.toml
@@ -71,6 +71,10 @@ maybe-async = { workspace = true }
async-trait = "0.1.77"
smallvec = { workspace = true, features = ["serde"] }
+[target.'cfg(target_arch = "wasm32")'.dependencies.web-sys]
+features = ["GpuBuffer", "GpuDevice"]
+workspace = true
+
[dev-dependencies]
env_logger = { workspace = true }
rand = { workspace = true }
diff --git a/crates/ratchet-core/src/backprop.rs b/crates/ratchet-core/src/backprop.rs
index 875977e7..afc8f425 100644
--- a/crates/ratchet-core/src/backprop.rs
+++ b/crates/ratchet-core/src/backprop.rs
@@ -3,9 +3,9 @@
/// Methods for backpropagation of gradients.
use crate::ops::{BinaryOp, UnaryOp};
use crate::{
- rvec, Affine, Alibi, Binary, Broadcast, Cmp, Concat, Conv, DType, Gather, GroupNorm, IndexAdd,
- IndexSelect, LazyOp, Matmul, Norm, NormOp, Permute, Powf, Reduce, ReduceOp, Reindex, RoPE,
- ScatterAdd, Shape, Slice, Softmax, Tensor, TensorId, Unary, View, WhereCond,
+ rvec, Affine, Alibi, Binary, Broadcast, Cast, Cmp, Concat, Conv, DType, Gather, GroupNorm,
+ IndexAdd, IndexSelect, LazyOp, Matmul, Norm, NormOp, Permute, Powf, Reduce, ReduceOp, Reindex,
+ RoPE, ScatterAdd, ScopePusher, Shape, Slice, Softmax, Tensor, TensorId, Unary, View, WhereCond,
};
use crate::{HashMap, Trilu};
use anyhow::Result;
@@ -174,8 +174,16 @@ impl Tensor {
track_grad |= tg;
nodes
}
+ LazyOp::Cast(Cast { input, .. }) => {
+ if input.dt().is_float() {
+ let (tg, nodes) = walk(input, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ } else {
+ nodes
+ }
+ }
LazyOp::IndexWrite(_) => todo!(),
- LazyOp::Cast(_) => todo!(),
LazyOp::Copy(_) => todo!(),
LazyOp::Detach(_)
| LazyOp::Const
@@ -186,6 +194,7 @@ impl Tensor {
})
| LazyOp::FillConstant(_)
| LazyOp::FillRandn(_)
+ | LazyOp::Bernoulli(_)
| LazyOp::Arange(_)
| LazyOp::Cache(_)
| LazyOp::Trilu(_) => nodes,
@@ -203,10 +212,12 @@ impl Tensor {
}
pub fn backward(&self) -> Result {
+ let _scope_guard = ScopePusher::new("backward");
let sorted_nodes = self.sorted_nodes();
let mut grads = GradStore::new();
grads.insert(self, self.ones_like::().contiguous());
for node in sorted_nodes.iter() {
+ let _op_scope_guard = ScopePusher::new(&format!("for:{}", node.op().name()));
if node.is_variable() {
continue;
}
@@ -574,6 +585,7 @@ impl Tensor {
})
| LazyOp::FillConstant(_)
| LazyOp::FillRandn(_)
+ | LazyOp::Bernoulli(_)
| LazyOp::Arange(_) => {}
LazyOp::View(View { src: arg, .. }) => {
let arg_grad = grad.clone().view(arg.shape().clone())?;
@@ -687,12 +699,14 @@ impl Tensor {
LazyOp::Alibi(Alibi { input, .. }) => {
grads.accumulate_add(input, grad)?;
}
+ LazyOp::Cast(Cast { input, .. }) => {
+ grads.accumulate_add(input, grad.cast(input.dt())?)?;
+ }
LazyOp::Norm(_) => todo!(),
LazyOp::Const => panic!("ratchet internal error - const node in backprop"),
LazyOp::Concat(_) => todo!(),
LazyOp::Cmp(_) => todo!(),
LazyOp::Powf(_) => todo!(),
- LazyOp::Cast(_) => todo!(),
LazyOp::RoPE(RoPE {
input: arg,
dim,
diff --git a/crates/ratchet-core/src/compiled_op.rs b/crates/ratchet-core/src/compiled_op.rs
index 09876e93..46f7b7c6 100644
--- a/crates/ratchet-core/src/compiled_op.rs
+++ b/crates/ratchet-core/src/compiled_op.rs
@@ -1,12 +1,15 @@
#[cfg(feature = "debug")]
use crate::gpu::BindGroupLayoutEntryDescriptor;
-use crate::gpu::{
- BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup, WgpuDevice,
- WorkgroupCount,
-};
#[cfg(feature = "debug")]
use crate::TensorId;
use crate::{drvec, rvec, KernelKey, OperationError, PooledGPUBuffer, RVec, Tensor};
+use crate::gpu::{
+    BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup,
+    WgpuDevice, WorkgroupCount,
+};
+// `TensorId` is already imported under #[cfg(feature = "debug")] above; cover the other cfg here.
+#[cfg(not(feature = "debug"))]
+use crate::TensorId;
use derive_new::new;
use std::sync::Arc;
use wgpu::DynamicOffset;
@@ -40,6 +43,11 @@ pub struct CompiledOp {
pub(crate) storage_groups: RVec,
pub(crate) offset: DynamicOffset, //offset into the metadata uniform buffer
pub kernel_key: KernelKey,
+ // Mapping between tensor and compiled op is not necessarily 1:1—for example: things like AMP
+ // likely insert casts between tensor operations.
+ pub tensor_id: Option,
+ #[cfg(not(feature = "debug"))]
+ pub debug_buffer: Option>,
#[cfg(feature = "debug")]
pub debug_buffer: Option<(TensorId, Arc)>,
#[cfg(feature = "debug")]
diff --git a/crates/ratchet-core/src/cpu/binary.rs b/crates/ratchet-core/src/cpu/binary.rs
index 146d80f4..b927fa61 100644
--- a/crates/ratchet-core/src/cpu/binary.rs
+++ b/crates/ratchet-core/src/cpu/binary.rs
@@ -2,6 +2,7 @@ use crate::cpu::cpu_store_result;
use crate::{Binary, BinaryOp, CPUOperation, DType, OperationError, Tensor, TensorDType};
use core::marker::PhantomData;
use half::{bf16, f16};
+use maybe_async::maybe_async;
use num_traits::NumOps;
#[inline]
@@ -32,14 +33,15 @@ pub(crate) fn binary_map_inplace(lhs: &mut [T], rhs: &[T], f: fn
}
#[inline]
-pub(crate) fn binary_apply(
+#[maybe_async]
+pub(crate) async fn binary_apply(
lhs: &Tensor,
rhs: &Tensor,
dst: &Tensor,
f: fn(T, T) -> U,
) -> Result<(), OperationError> {
- let lhs = lhs.to_vec::()?;
- let rhs = rhs.to_vec::()?;
+ let lhs = lhs.to_vec::().await?;
+ let rhs = rhs.to_vec::().await?;
let mut result = vec![U::zero(); dst.shape().numel()];
binary_map(&lhs, &rhs, &mut result, f);
cpu_store_result(dst, &result);
@@ -47,14 +49,15 @@ pub(crate) fn binary_apply(
}
#[inline]
-pub(crate) fn binary_apply_inplace(
+#[maybe_async]
+pub(crate) async fn binary_apply_inplace(
lhs: &Tensor,
rhs: &Tensor,
dst: &Tensor,
f: fn(T, T) -> T,
) -> Result<(), OperationError> {
- let mut lhs = lhs.to_vec::()?;
- let rhs = rhs.to_vec::()?;
+ let mut lhs = lhs.to_vec::().await?;
+ let rhs = rhs.to_vec::().await?;
binary_map_inplace(&mut lhs, &rhs, f);
cpu_store_result(dst, &lhs);
Ok(())
@@ -66,8 +69,13 @@ pub struct BinaryOps {
macro_rules! impl_cpu_binary_op {
($method_name:ident, $dtype:ident, $op:expr) => {
- fn $method_name(lhs: &Tensor, rhs: &Tensor, dst: Tensor) -> Result {
- binary_apply_inplace::<$dtype>(lhs, rhs, &dst, $op)?;
+ #[maybe_async]
+ async fn $method_name(
+ lhs: &Tensor,
+ rhs: &Tensor,
+ dst: Tensor,
+ ) -> Result {
+ binary_apply_inplace::<$dtype>(lhs, rhs, &dst, $op).await?;
Ok(dst)
}
};
@@ -95,24 +103,28 @@ macro_rules! impl_cpu_binary {
impl_cpu_binary_op!(mul, $dtype, |lhs, rhs| lhs * rhs);
impl_cpu_binary_op!(div, $dtype, |lhs, rhs| lhs / rhs);
- pub fn apply(op: &Binary, dst: Tensor) -> Result {
+ #[maybe_async]
+ pub async fn apply(op: &Binary, dst: Tensor) -> Result {
match op.op() {
- BinaryOp::Add => Self::add(op.lhs(), op.rhs(), dst),
- BinaryOp::Sub => Self::sub(op.lhs(), op.rhs(), dst),
- BinaryOp::Mul => Self::mul(op.lhs(), op.rhs(), dst),
- BinaryOp::Div => Self::div(op.lhs(), op.rhs(), dst),
+ BinaryOp::Add => Self::add(op.lhs(), op.rhs(), dst).await,
+ BinaryOp::Sub => Self::sub(op.lhs(), op.rhs(), dst).await,
+ BinaryOp::Mul => Self::mul(op.lhs(), op.rhs(), dst).await,
+ BinaryOp::Div => Self::div(op.lhs(), op.rhs(), dst).await,
}
}
}
};
}
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Binary {
- fn apply_cpu(&self, dst: Tensor) -> Result {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result {
match dst.dt() {
- DType::F32 => BinaryOps::::apply(self, dst),
- DType::F16 => BinaryOps::::apply(self, dst),
- DType::BF16 => BinaryOps::::apply(self, dst),
+ DType::F32 => BinaryOps::::apply(self, dst).await,
+ DType::F16 => BinaryOps::::apply(self, dst).await,
+ DType::BF16 => BinaryOps::::apply(self, dst).await,
_ => todo!(),
}
}
diff --git a/crates/ratchet-core/src/cpu/gemm.rs b/crates/ratchet-core/src/cpu/gemm.rs
index b932b37b..6e68ea34 100644
--- a/crates/ratchet-core/src/cpu/gemm.rs
+++ b/crates/ratchet-core/src/cpu/gemm.rs
@@ -6,6 +6,7 @@ use anyhow::{anyhow, Result};
use core::str::FromStr;
use gemm::{gemm as gemm_kernel, Parallelism};
use half::{bf16, f16};
+use maybe_async::maybe_async;
use std::num::NonZeroUsize;
fn get_num_threads() -> NonZeroUsize {
@@ -155,16 +156,20 @@ fn gemm_impl(
)
}
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Matmul {
- fn apply_cpu(&self, dst: Tensor) -> Result {
- fn run_gemm(
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result {
+ #[maybe_async]
+ async fn run_gemm(
spec: MatmulSpec,
lhs: &Tensor,
rhs: &Tensor,
dst: &Tensor,
) -> Result<(), OperationError> {
- let lhs = lhs.to_vec::()?;
- let rhs = rhs.to_vec::()?;
+ let lhs = lhs.to_vec::().await?;
+ let rhs = rhs.to_vec::().await?;
let result = if spec.trans_dst() {
gemm_impl::(spec, &rhs, &lhs)?
@@ -179,9 +184,9 @@ impl CPUOperation for Matmul {
let Matmul { lhs, rhs, .. } = self;
match self.lhs.dt() {
- DType::F32 => run_gemm::(spec, lhs, rhs, &dst),
- DType::F16 => run_gemm::(spec, lhs, rhs, &dst),
- DType::BF16 => run_gemm::(spec, lhs, rhs, &dst),
+ DType::F32 => run_gemm::(spec, lhs, rhs, &dst).await,
+ DType::F16 => run_gemm::(spec, lhs, rhs, &dst).await,
+ DType::BF16 => run_gemm::(spec, lhs, rhs, &dst).await,
dtype => Err(InvariantError::UnsupportedDType(dtype))?,
}?;
Ok(dst)
diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs
index 67c6f370..fed802dc 100644
--- a/crates/ratchet-core/src/cpu/mod.rs
+++ b/crates/ratchet-core/src/cpu/mod.rs
@@ -13,27 +13,29 @@ use crate::{
};
use anyhow::anyhow;
use half::{bf16, f16};
+use maybe_async::maybe_async;
use rope::cpu_rope;
use unary::unary_apply_fn;
use utils::cpu_store_result;
-pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result {
+#[maybe_async]
+pub async fn apply_operation(op: LazyOp, dst: Tensor) -> Result {
match op {
- LazyOp::Binary(b) => b.apply_cpu(dst),
- LazyOp::Cast(c) => cpu_cast(c, dst),
- LazyOp::Matmul(m) => m.apply_cpu(dst),
- LazyOp::Softmax(s) => s.apply_cpu(dst),
- LazyOp::RoPE(r) => cpu_rope(r, dst),
+ LazyOp::Binary(b) => b.apply_cpu(dst).await,
+ LazyOp::Cast(c) => cpu_cast(c, dst).await,
+ LazyOp::Matmul(m) => m.apply_cpu(dst).await,
+ LazyOp::Softmax(s) => s.apply_cpu(dst).await,
+ LazyOp::RoPE(r) => cpu_rope(r, dst).await,
LazyOp::Alibi(a) => todo!(),
- LazyOp::Unary(u) => u.apply_cpu(dst),
- LazyOp::Reindex(r) => r.apply_cpu(dst),
- LazyOp::Concat(c) => cpu_concat(c, dst),
- LazyOp::Norm(n) => n.apply_cpu(dst),
+ LazyOp::Unary(u) => u.apply_cpu(dst).await,
+ LazyOp::Reindex(r) => r.apply_cpu(dst).await,
+ LazyOp::Concat(c) => cpu_concat(c, dst).await,
+ LazyOp::Norm(n) => n.apply_cpu(dst).await,
LazyOp::Affine(_a) => todo!(),
LazyOp::Cmp(_c) => todo!(),
LazyOp::Powf(_p) => todo!(),
LazyOp::Conv(_c) => todo!(),
- LazyOp::Select(i) => cpu_index_select(i, dst),
+ LazyOp::Select(i) => cpu_index_select(i, dst).await,
LazyOp::IndexWrite(_i) => todo!(),
LazyOp::Cache(_c) => todo!(),
LazyOp::Trilu(_t) => todo!(),
@@ -44,6 +46,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result todo!(),
LazyOp::FillConstant(_f) => todo!(),
LazyOp::FillRandn(_f) => todo!(),
+ LazyOp::Bernoulli(_b) => todo!(),
LazyOp::Arange(_a) => todo!(),
LazyOp::IndexAdd(_i) => todo!(),
LazyOp::ScatterAdd(_s) => todo!(),
@@ -52,11 +55,14 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result Result;
+ async fn apply_cpu(&self, dst: Tensor) -> Result;
}
-fn index_select(
+#[maybe_async]
+async fn index_select(
index_select: IndexSelect,
dst: Tensor,
) -> Result {
@@ -84,8 +90,8 @@ fn index_select(
let left_len: usize = dst_dims[..dim].iter().product();
let right_len: usize = dst_dims[dim + 1..].iter().product();
- let src = src.to_vec::()?;
- let indices = indices.to_vec::()?;
+ let src = src.to_vec::().await?;
+ let indices = indices.to_vec::().await?;
let mut result = vec![T::zero(); dst_len];
for left_i in 0..left_len {
@@ -102,67 +108,77 @@ fn index_select(
Ok(dst)
}
-fn qindex_select(op: IndexSelect, dst: Tensor) -> Result {
+#[maybe_async]
+async fn qindex_select(op: IndexSelect, dst: Tensor) -> Result {
// NOTE: qindex_select is functional but not optimized at all.
// Currently we simply dequantize the entire input tensor to f32 and then call index_select.
// Because of borrowing rules dequantizing also requires a deep clone of the input tensor, which is less than ideal.
// In the future we would rather directly index the raw buffer of the quantized tensor and dequantize only what is required.
// TODO: Add support for direct indexing + partial dequantization
- let src = op.src().deep_clone();
+ let src = op.src().deep_clone().await;
// NOTE: Support for other quantization types is dependent on the corresponding dequantization functions.
let src = dequantize(src);
let indices = op.indices().clone();
let dim = op.dim();
- index_select::(IndexSelect::new(src, indices, dim), dst)
+ index_select::(IndexSelect::new(src, indices, dim), dst).await
}
-pub fn cpu_index_select(i: IndexSelect, dst: Tensor) -> Result {
+#[maybe_async]
+pub async fn cpu_index_select(i: IndexSelect, dst: Tensor) -> Result {
match i.src().dt() {
- DType::F32 => index_select::(i, dst),
- DType::F16 => index_select::(i, dst),
- DType::BF16 => index_select::(i, dst),
- DType::Q8_0F(_) => qindex_select(i, dst),
+ DType::F32 => index_select::(i, dst).await,
+ DType::F16 => index_select::(i, dst).await,
+ DType::BF16 => index_select::(i, dst).await,
+ DType::Q8_0F(_) => qindex_select(i, dst).await,
dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
}
}
-fn direct_cast(
+#[maybe_async]
+async fn direct_cast(
input: &Tensor,
dst: &Tensor,
) -> Result<(), OperationError> {
- let input = input.to_vec::()?;
+ let input = input.to_vec::().await?;
let result =
bytemuck::try_cast_slice::(&input).map_err(|_| anyhow!("Failed direct cast"))?;
cpu_store_result(dst, result);
Ok(())
}
-pub fn cpu_cast(cast: Cast, dst: Tensor) -> Result {
+#[maybe_async]
+pub async fn cpu_cast(cast: Cast, dst: Tensor) -> Result {
if cast.input().dt() == cast.dst_dt() {
return Ok(cast.input().clone());
}
match (cast.input().dt(), cast.dst_dt()) {
// F32 ->
- (DType::F32, DType::F16) => unary_apply_fn::(cast.input(), &dst, f16::from_f32)?,
+ (DType::F32, DType::F16) => {
+ unary_apply_fn::(cast.input(), &dst, f16::from_f32).await?
+ }
(DType::F32, DType::BF16) => {
- unary_apply_fn::(cast.input(), &dst, bf16::from_f32)?
+ unary_apply_fn::(cast.input(), &dst, bf16::from_f32).await?
}
- (DType::F32, DType::I32) => direct_cast::(cast.input(), &dst)?,
- (DType::F32, DType::U32) => direct_cast::(cast.input(), &dst)?,
+ (DType::F32, DType::I32) => direct_cast::(cast.input(), &dst).await?,
+ (DType::F32, DType::U32) => direct_cast::(cast.input(), &dst).await?,
// F16 ->
- (DType::F16, DType::F32) => unary_apply_fn::(cast.input(), &dst, f32::from)?,
+ (DType::F16, DType::F32) => {
+ unary_apply_fn::(cast.input(), &dst, f32::from).await?
+ }
// BF16 ->
- (DType::BF16, DType::F32) => unary_apply_fn::(cast.input(), &dst, f32::from)?,
+ (DType::BF16, DType::F32) => {
+ unary_apply_fn::(cast.input(), &dst, f32::from).await?
+ }
// I32 ->
- (DType::I32, DType::F32) => direct_cast::(cast.input(), &dst)?,
+ (DType::I32, DType::F32) => direct_cast::(cast.input(), &dst).await?,
// U32 ->
- (DType::U32, DType::F32) => direct_cast::(cast.input(), &dst)?,
+ (DType::U32, DType::F32) => direct_cast::(cast.input(), &dst).await?,
_ => unimplemented!("Cannot cast {:?} -> {:?}", cast.input().dt(), cast.dst_dt()),
};
@@ -195,7 +211,9 @@ pub(crate) fn concat(
}
Ok(())
}
-pub(crate) fn apply_concat(
+
+#[maybe_async]
+pub(crate) async fn apply_concat(
inputs: RVec,
dim: usize,
dst: Tensor,
@@ -203,24 +221,33 @@ pub(crate) fn apply_concat(
let dst_size = dst.shape().numel();
let mut result = vec![T::zero(); dst_size];
- let inputs = inputs
- .iter()
- .map(|t| match t.to_vec::() {
+ let mut inputs_result = Vec::with_capacity(inputs.len());
+ for t in inputs.iter() {
+ let result = match t.to_vec::().await {
Ok(v) => Ok((t.shape().clone(), v)),
Err(e) => Err(e.into()),
- })
- .collect::, OperationError>>();
+ };
+        match result {
+            Ok(v) => inputs_result.push(v),
+            Err(e) => return Err(e),
+        }
+ }
+ let inputs = inputs_result;
- concat(&inputs?, dim, dst.shape(), &mut result)?;
+ concat(&inputs, dim, dst.shape(), &mut result)?;
cpu_store_result(&dst, &result);
Ok(dst)
}
-pub fn cpu_concat(Concat { inputs, dim }: Concat, dst: Tensor) -> Result {
+#[maybe_async]
+pub async fn cpu_concat(
+ Concat { inputs, dim }: Concat,
+ dst: Tensor,
+) -> Result {
match dst.dt() {
- DType::F32 => apply_concat::(inputs, dim, dst),
- DType::F16 => apply_concat::(inputs, dim, dst),
- DType::BF16 => apply_concat::(inputs, dim, dst),
+ DType::F32 => apply_concat::(inputs, dim, dst).await,
+ DType::F16 => apply_concat::(inputs, dim, dst).await,
+ DType::BF16 => apply_concat::(inputs, dim, dst).await,
dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
}
}
diff --git a/crates/ratchet-core/src/cpu/norm.rs b/crates/ratchet-core/src/cpu/norm.rs
index ce3389ab..27b8e1de 100644
--- a/crates/ratchet-core/src/cpu/norm.rs
+++ b/crates/ratchet-core/src/cpu/norm.rs
@@ -9,20 +9,25 @@ use crate::{
};
use core::iter::Sum;
use half::{bf16, f16};
+use maybe_async::maybe_async;
use num::Float;
use num_traits::NumOps;
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for NormOp {
- fn apply_cpu(&self, dst: Tensor) -> Result {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result {
match self {
- NormOp::LayerNorm(n) => apply_layer_norm(n, dst),
- NormOp::RMSNorm(n) => apply_rms_norm(n, dst),
- NormOp::GroupNorm(g) => apply_group_norm(g, dst),
+ NormOp::LayerNorm(n) => apply_layer_norm(n, dst).await,
+ NormOp::RMSNorm(n) => apply_rms_norm(n, dst).await,
+ NormOp::GroupNorm(g) => apply_group_norm(g, dst).await,
}
}
}
-fn apply_layer_norm(
+#[maybe_async]
+async fn apply_layer_norm(
Norm {
input,
scale,
@@ -49,9 +54,9 @@ fn apply_layer_norm(
}
match input.dt() {
- DType::F32 => layer_norm::(input, scale, bias, *eps, &dst)?,
- DType::F16 => layer_norm::(input, scale, bias, *eps, &dst)?,
- DType::BF16 => layer_norm::(input, scale, bias, *eps, &dst)?,
+ DType::F32 => layer_norm::(input, scale, bias, *eps, &dst).await?,
+ DType::F16 => layer_norm::(input, scale, bias, *eps, &dst).await?,
+ DType::BF16 => layer_norm::(input, scale, bias, *eps, &dst).await?,
_ => todo!(),
};
@@ -88,7 +93,8 @@ where
result
}
-fn layer_norm(
+#[maybe_async]
+async fn layer_norm(
input: &Tensor,
scale: &Tensor,
bias: &Option,
@@ -103,10 +109,10 @@ where
let N = src_shape[rank - 1];
let norm_shape = shape!(N);
- let input = input.to_vec::()?;
- let scale = scale.to_vec::()?;
+ let input = input.to_vec::().await?;
+ let scale = scale.to_vec::().await?;
let bias = match bias {
- Some(b) => Some(b.to_vec::()?),
+ Some(b) => Some(b.to_vec::().await?),
None => None,
};
@@ -146,7 +152,8 @@ where
Ok(())
}
-fn apply_rms_norm(
+#[maybe_async]
+async fn apply_rms_norm(
Norm {
input,
scale,
@@ -173,16 +180,22 @@ fn apply_rms_norm(
}
match input.dt() {
- DType::F32 => rms_norm::(input, scale, *eps, &dst)?,
- DType::F16 => rms_norm::(input, scale, *eps, &dst)?,
- DType::BF16 => rms_norm::(input, scale, *eps, &dst)?,
+ DType::F32 => rms_norm::(input, scale, *eps, &dst).await?,
+ DType::F16 => rms_norm::(input, scale, *eps, &dst).await?,
+ DType::BF16 => rms_norm::(input, scale, *eps, &dst).await?,
_ => todo!(),
};
Ok(dst)
}
-fn rms_norm(input: &Tensor, scale: &Tensor, eps: f32, dst: &Tensor) -> Result<(), OperationError>
+#[maybe_async]
+async fn rms_norm(
+ input: &Tensor,
+ scale: &Tensor,
+ eps: f32,
+ dst: &Tensor,
+) -> Result<(), OperationError>
where
T: TensorDType + Float + NumOps + for<'a> Sum<&'a T>,
{
@@ -190,8 +203,8 @@ where
let rank = input.rank();
let N = src_shape[rank - 1];
- let mut x = input.to_vec::()?;
- let scale = scale.to_vec::()?;
+ let mut x = input.to_vec::().await?;
+ let scale = scale.to_vec::().await?;
let mut x2 = x.clone();
square(&mut x2);
@@ -212,7 +225,8 @@ where
Ok(())
}
-fn apply_group_norm(_n: &GroupNorm, dst: Tensor) -> Result {
+#[maybe_async]
+async fn apply_group_norm(_n: &GroupNorm, dst: Tensor) -> Result {
//let result = norm(&b.src.to_vec::()?, b.src.shape(), b.to());
//cpu_store_result(&dst, &result);
Ok(dst)
diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs
index 3f33425e..db7089aa 100644
--- a/crates/ratchet-core/src/cpu/reindex.rs
+++ b/crates/ratchet-core/src/cpu/reindex.rs
@@ -4,34 +4,42 @@ use crate::{
Tensor, TensorDType,
};
use half::{bf16, f16};
+use maybe_async::maybe_async;
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Reindex {
- fn apply_cpu(&self, dst: Tensor) -> Result {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result {
match self {
- Reindex::Permute(p) => p.apply_cpu(dst),
- Reindex::Slice(s) => s.apply_cpu(dst),
- Reindex::Broadcast(b) => b.apply_cpu(dst),
+ Reindex::Permute(p) => p.apply_cpu(dst).await,
+ Reindex::Slice(s) => s.apply_cpu(dst).await,
+ Reindex::Broadcast(b) => b.apply_cpu(dst).await,
}
}
}
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Permute {
- fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
- DType::F32 => apply_permute::<f32>(self, dst),
- DType::BF16 => apply_permute::<bf16>(self, dst),
- DType::F16 => apply_permute::<f16>(self, dst),
- DType::I32 => apply_permute::<i32>(self, dst),
- DType::U32 => apply_permute::<u32>(self, dst),
+ DType::F32 => apply_permute::<f32>(self, dst).await,
+ DType::BF16 => apply_permute::<bf16>(self, dst).await,
+ DType::F16 => apply_permute::<f16>(self, dst).await,
+ DType::I32 => apply_permute::<i32>(self, dst).await,
+ DType::U32 => apply_permute::<u32>(self, dst).await,
_ => todo!(),
}
}
}
-fn apply_permute<T: TensorDType>(p: &Permute, dst: Tensor) -> Result<Tensor, OperationError> {
+#[maybe_async]
+async fn apply_permute<T: TensorDType>(p: &Permute, dst: Tensor) -> Result<Tensor, OperationError> {
let perm: [usize; 4] = p.promote().as_slice().try_into().unwrap();
let Permute { src, dims: _ } = p;
- let result = permute(&src.to_vec::<T>()?, src.shape(), dst.shape(), perm);
+ let result = permute(&src.to_vec::<T>().await?, src.shape(), dst.shape(), perm);
cpu_store_result(&dst, &result);
Ok(dst)
}
@@ -70,22 +78,26 @@ fn permute(
result
}
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Slice {
- fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
- DType::F32 => apply_slice::<f32>(self, dst),
- DType::BF16 => apply_slice::<bf16>(self, dst),
- DType::F16 => apply_slice::<f16>(self, dst),
- DType::I32 => apply_slice::<i32>(self, dst),
- DType::U32 => apply_slice::<u32>(self, dst),
+ DType::F32 => apply_slice::<f32>(self, dst).await,
+ DType::BF16 => apply_slice::<bf16>(self, dst).await,
+ DType::F16 => apply_slice::<f16>(self, dst).await,
+ DType::I32 => apply_slice::<i32>(self, dst).await,
+ DType::U32 => apply_slice::<u32>(self, dst).await,
_ => todo!(),
}
}
}
-fn apply_slice<T: TensorDType>(s: &Slice, dst: Tensor) -> Result<Tensor, OperationError> {
+#[maybe_async]
+async fn apply_slice<T: TensorDType>(s: &Slice, dst: Tensor) -> Result<Tensor, OperationError> {
let (start, stop): (Vec<_>, Vec<_>) = s.indices().iter().map(|r| (r.start, r.end)).unzip();
- let result = slice(&s.src.to_vec::<T>()?, s.src.strides(), &start, &stop);
+ let result = slice(&s.src.to_vec::<T>().await?, s.src.strides(), &start, &stop);
cpu_store_result(&dst, &result);
Ok(dst)
@@ -127,21 +139,28 @@ pub(crate) fn slice(
dst
}
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Broadcast {
- fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
- DType::F32 => apply_broadcast::<f32>(self, dst),
- DType::BF16 => apply_broadcast::<bf16>(self, dst),
- DType::F16 => apply_broadcast::<f16>(self, dst),
- DType::I32 => apply_broadcast::<i32>(self, dst),
- DType::U32 => apply_broadcast::<u32>(self, dst),
+ DType::F32 => apply_broadcast::<f32>(self, dst).await,
+ DType::BF16 => apply_broadcast::<bf16>(self, dst).await,
+ DType::F16 => apply_broadcast::<f16>(self, dst).await,
+ DType::I32 => apply_broadcast::<i32>(self, dst).await,
+ DType::U32 => apply_broadcast::<u32>(self, dst).await,
_ => todo!(),
}
}
}
-fn apply_broadcast<T: TensorDType>(b: &Broadcast, dst: Tensor) -> Result<Tensor, OperationError> {
- let result = broadcast(&b.src.to_vec::<T>()?, b.src.shape(), b.to());
+#[maybe_async]
+async fn apply_broadcast<T: TensorDType>(
+ b: &Broadcast,
+ dst: Tensor,
+) -> Result<Tensor, OperationError> {
+ let result = broadcast(&b.src.to_vec::<T>().await?, b.src.shape(), b.to());
cpu_store_result(&dst, &result);
Ok(dst)
}
diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs
index a1d407f0..bc50d96b 100644
--- a/crates/ratchet-core/src/cpu/rope.rs
+++ b/crates/ratchet-core/src/cpu/rope.rs
@@ -3,14 +3,16 @@ use crate::{
cpu::{cpu_store_result, gemm::gemm, reindex::slice},
shape, DType, OperationError, RoPE, Shape, Strides, Tensor,
};
+use maybe_async::maybe_async;
-pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
+#[maybe_async]
+pub async fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
match op.input().dt() {
DType::F32 => {
let dim = op.dim();
let base = op.base();
let offset = op.offset();
- let src = op.input().to_vec::<f32>()?;
+ let src = op.input().to_vec::<f32>().await?;
let result = rope(src, op.input().shape(), dim, base, offset)?;
cpu_store_result(&dst, &result)
}
diff --git a/crates/ratchet-core/src/cpu/softmax.rs b/crates/ratchet-core/src/cpu/softmax.rs
index 1d6a3df0..d4aafa54 100644
--- a/crates/ratchet-core/src/cpu/softmax.rs
+++ b/crates/ratchet-core/src/cpu/softmax.rs
@@ -1,16 +1,20 @@
use crate::cpu::utils::cpu_store_result;
use crate::{CPUOperation, DType, OperationError, Softmax, Tensor, TensorDType};
use half::{bf16, f16};
+use maybe_async::maybe_async;
use num::Float;
use num_traits::NumAssignOps;
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Softmax {
- fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
let Softmax { input, dim } = self;
match input.dt() {
- DType::F32 => softmax::<f32>(input, *dim, &dst)?,
- DType::F16 => softmax::<f16>(input, *dim, &dst)?,
- DType::BF16 => softmax::<bf16>(input, *dim, &dst)?,
+ DType::F32 => softmax::<f32>(input, *dim, &dst).await?,
+ DType::F16 => softmax::<f16>(input, *dim, &dst).await?,
+ DType::BF16 => softmax::<bf16>(input, *dim, &dst).await?,
_ => todo!(),
}
@@ -18,12 +22,13 @@ impl CPUOperation for Softmax {
}
}
-fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError>
+#[maybe_async]
+async fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError>
where
T: TensorDType + Float + NumAssignOps,
{
let src_shape = input.shape();
- let mut input = input.to_vec::<T>()?;
+ let mut input = input.to_vec::<T>().await?;
let N = src_shape[dim];
input.chunks_mut(N).for_each(|chunk| {
let mut sum = T::zero();
diff --git a/crates/ratchet-core/src/cpu/unary.rs b/crates/ratchet-core/src/cpu/unary.rs
index 54271b30..0554757f 100644
--- a/crates/ratchet-core/src/cpu/unary.rs
+++ b/crates/ratchet-core/src/cpu/unary.rs
@@ -2,6 +2,7 @@ use crate::cpu::cpu_store_result;
use crate::{CPUOperation, DType, OperationError, Tensor, TensorDType, Unary, UnaryOp};
use core::marker::PhantomData;
use half::{bf16, f16};
+use maybe_async::maybe_async;
use num_traits::Float;
#[inline]
@@ -24,12 +25,13 @@ pub(crate) fn unary_map_inplace(src: &mut [T], f: fn(T) -> T) {
}
#[inline]
-pub(crate) fn unary_apply_fn<T: TensorDType, U: TensorDType>(
+#[maybe_async]
+pub(crate) async fn unary_apply_fn<T: TensorDType, U: TensorDType>(
input: &Tensor,
dst: &Tensor,
f: fn(T) -> U,
) -> Result<(), OperationError> {
- let input = input.to_vec::<T>()?;
+ let input = input.to_vec::<T>().await?;
let mut result = vec![U::zero(); dst.shape().numel()];
unary_apply_fn_helper(&input, &mut result, f);
cpu_store_result(dst, &result);
@@ -70,26 +72,27 @@ macro_rules! impl_unary_ops {
* x
* ($conv(1.0) / ($conv(1.0) + (-x).exp())));
- fn apply(op: &Unary, dst: Tensor) -> Result<Tensor, OperationError> {
+ #[maybe_async]
+ async fn apply(op: &Unary, dst: Tensor) -> Result<Tensor, OperationError> {
match op.op() {
- UnaryOp::Gelu => Self::gelu(op.input(), dst),
- UnaryOp::Tanh => Self::tanh(op.input(), dst),
- UnaryOp::Exp => Self::exp(op.input(), dst),
- UnaryOp::Log => Self::log(op.input(), dst),
- UnaryOp::Sin => Self::sin(op.input(), dst),
- UnaryOp::Cos => Self::cos(op.input(), dst),
- UnaryOp::Abs => Self::abs(op.input(), dst),
- UnaryOp::Square => Self::square(op.input(), dst),
- UnaryOp::Sqrt => Self::sqrt(op.input(), dst),
- UnaryOp::Relu => Self::relu(op.input(), dst),
- UnaryOp::Relu2 => Self::relu2(op.input(), dst),
- UnaryOp::Floor => Self::floor(op.input(), dst),
- UnaryOp::Ceil => Self::ceil(op.input(), dst),
- UnaryOp::Neg => Self::neg(op.input(), dst),
- UnaryOp::Reciprocal => Self::reciprocal(op.input(), dst),
- UnaryOp::Silu => Self::silu(op.input(), dst),
- UnaryOp::Sigmoid => Self::sigmoid(op.input(), dst),
- UnaryOp::Swiglu => Self::swiglu(op.input(), dst),
+ UnaryOp::Gelu => Self::gelu(op.input(), dst).await,
+ UnaryOp::Tanh => Self::tanh(op.input(), dst).await,
+ UnaryOp::Exp => Self::exp(op.input(), dst).await,
+ UnaryOp::Log => Self::log(op.input(), dst).await,
+ UnaryOp::Sin => Self::sin(op.input(), dst).await,
+ UnaryOp::Cos => Self::cos(op.input(), dst).await,
+ UnaryOp::Abs => Self::abs(op.input(), dst).await,
+ UnaryOp::Square => Self::square(op.input(), dst).await,
+ UnaryOp::Sqrt => Self::sqrt(op.input(), dst).await,
+ UnaryOp::Relu => Self::relu(op.input(), dst).await,
+ UnaryOp::Relu2 => Self::relu2(op.input(), dst).await,
+ UnaryOp::Floor => Self::floor(op.input(), dst).await,
+ UnaryOp::Ceil => Self::ceil(op.input(), dst).await,
+ UnaryOp::Neg => Self::neg(op.input(), dst).await,
+ UnaryOp::Reciprocal => Self::reciprocal(op.input(), dst).await,
+ UnaryOp::Silu => Self::silu(op.input(), dst).await,
+ UnaryOp::Sigmoid => Self::sigmoid(op.input(), dst).await,
+ UnaryOp::Swiglu => Self::swiglu(op.input(), dst).await,
}
}
}
@@ -98,19 +101,23 @@ macro_rules! impl_unary_ops {
macro_rules! impl_cpu_unary_op {
($method_name:ident, $op:expr) => {
- fn $method_name(input: &Tensor, dst: Tensor) -> Result<Tensor, OperationError> {
- unary_apply_fn(input, &dst, $op)?;
+ #[maybe_async]
+ async fn $method_name(input: &Tensor, dst: Tensor) -> Result<Tensor, OperationError> {
+ unary_apply_fn(input, &dst, $op).await?;
Ok(dst)
}
};
}
+#[maybe_async(AFIT)]
+#[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
impl CPUOperation for Unary {
- fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+ #[maybe_async]
+ async fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
- DType::F32 => UnaryOps::<f32>::apply(self, dst),
- DType::F16 => UnaryOps::<f16>::apply(self, dst),
- DType::BF16 => UnaryOps::<bf16>::apply(self, dst),
+ DType::F32 => UnaryOps::<f32>::apply(self, dst).await,
+ DType::F16 => UnaryOps::<f16>::apply(self, dst).await,
+ DType::BF16 => UnaryOps::<bf16>::apply(self, dst).await,
_ => todo!(),
}
}
@@ -118,7 +125,7 @@ impl CPUOperation for Unary {
macro_rules! impl_cpu_unary {
($dtype:ident) => {
- impl_cpu_unary!($dtype, |x| x);
+ impl_cpu_unary!($dtype, |x: $dtype| -> $dtype { x });
};
($dtype:ident, $conv:expr) => {
impl_unary_ops!($dtype, $conv);
diff --git a/crates/ratchet-core/src/executable.rs b/crates/ratchet-core/src/executable.rs
index 150c8eda..8f1fbae3 100644
--- a/crates/ratchet-core/src/executable.rs
+++ b/crates/ratchet-core/src/executable.rs
@@ -1,21 +1,26 @@
use crate::gpu::{GpuUniform, PoolError, StaticResourcePoolAccessor, WgpuDevice};
-use crate::{Compiled, Storage};
+use crate::{
+ Compiled, CompiledCopy, CompiledOp, ExportedTensorProfilingEntry, HashMap, StepLogConfig,
+ Storage, TensorId,
+};
+
// #[cfg(feature = "debug")] (for tensor?)
+use crate::gpu::Profiler;
use crate::Tensor;
#[cfg(feature = "debug")]
-use crate::{CPUBuffer, DType, HashMap, TensorId};
+use crate::{DType, HashMap, TensorId};
#[cfg(feature = "debug")]
+use crate::{DeviceStorage, KernelKey, RVec};
+use derive_new::new;
use maybe_async::maybe_async;
#[cfg(feature = "debug")]
use parking_lot::RwLock;
#[cfg(feature = "debug")]
use slotmap::Key;
-#[cfg(feature = "debug")]
+#[cfg(not(feature = "debug"))]
+use std::collections::BTreeMap;
+use std::iter::Peekable;
use std::sync::Arc;
-
-#[cfg(feature = "debug")]
-use crate::{wgpu_buffer_to_cpu_buffer, DeviceStorage, KernelKey, RVec};
-use derive_new::new;
use wgpu::SubmissionIndex;
#[cfg(feature = "debug")]
@@ -36,6 +41,8 @@ pub struct Executable {
storage: Option>>,
pub(crate) steps: Vec,
gpu_uniform: GpuUniform,
+ #[cfg(not(feature = "debug"))]
+ pub(crate) debug_list: Option>,
#[cfg(feature = "debug")]
pub(crate) debug_list: Option