From 98cf23386857974e4eef63e244c3f15822b7a3b0 Mon Sep 17 00:00:00 2001 From: cearX Date: Wed, 2 Jul 2025 15:40:27 +0800 Subject: [PATCH 1/3] =?UTF-8?q?fix(nn):=20=E4=BF=AE=E6=AD=A3=20patch=5Femb?= =?UTF-8?q?d=20=E5=B1=82=E5=BD=A2=E7=8A=B6=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 1_nn/src/nn/patch_embd.rs | 41 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/1_nn/src/nn/patch_embd.rs b/1_nn/src/nn/patch_embd.rs index 8453519..dc793f9 100644 --- a/1_nn/src/nn/patch_embd.rs +++ b/1_nn/src/nn/patch_embd.rs @@ -1,6 +1,6 @@ use super::{Context, Distribution, NNError, NuralNetwork, TPTensor, Tensor, macros::destruct}; use crate::macros::dims; -use arg::Arg; +use arg::{Arg, Dim}; use tensor::digit_layout::DigitLayout; #[derive(Clone)] @@ -41,28 +41,25 @@ impl NuralNetwork for PatchEmbd { ) -> Result<(Context, Vec>), NNError> { destruct!([x] = inputs); - dims!([n, c, h, w] = x); - let [n, c, height, width] = [n, c, h, w].map(|v| v.to_usize()); + dims!([n, _c, height, width] = x); let Self { dt, shape, patch_embd, patch_embd1, } = self; - let [m, ck, hk, wk] = shape; - assert_eq!(n, 1); - assert_eq!(c, ck); - assert_eq!(hk, wk); + let [m, ck, hk, wk] = shape.map(Dim::from); + assert!(hk.eq(&wk)); let w = ctx.load_external( "patch_embd", dt, - [m.into(), ck.into(), hk.into(), wk.into()], + [m.clone(), ck.clone(), hk.clone(), wk.clone()], patch_embd, ); let w1 = ctx.load_external( "patch_embd1", dt, - [m.into(), ck.into(), hk.into(), wk.into()], + [m.clone(), ck.clone(), hk.clone(), wk.clone()], patch_embd1, ); let tensors = ctx @@ -78,8 +75,8 @@ impl NuralNetwork for PatchEmbd { // reshape - let hp = (height / hk) as u64; // h patches - let wp = (width / wk) as u64; // w patches + let hp = height.clone() / hk.clone(); // h patches + let wp = width.clone() / wk.clone(); // w patches // [n, m, hp, wp] -> [n, hp, wp, m] destruct!( [image_embd] = ctx @@ -94,6 +91,7 @@ impl NuralNetwork for PatchEmbd { ) .unwrap() ); + // [n, hp, wp, m] -> [n * hp/2, 2, wp/2, 2*m] destruct!( [image_embd] = ctx .call( @@ -101,7 +99,10 @@ impl NuralNetwork for PatchEmbd { "tile", Some(Arg::dict([ ("axis".into(), Arg::int(1)), - ("tiles".into(), Arg::arr([hp / 2, 2].map(Arg::from)),) + ( + "tiles".into(), + Arg::arr([hp.clone() / 2, Dim::from(2)].map(Arg::from)), + ) ])), [image_embd], ) @@ -127,7 +128,10 @@ impl NuralNetwork for PatchEmbd { "tile", Some(Arg::dict([ ("axis".into(), Arg::int(2)), - ("tiles".into(), Arg::arr([wp / 2, 2].map(Arg::from)),) + ( + "tiles".into(), + Arg::arr([wp / 2, Dim::from(2)].map(Arg::from)), + ) ])), [image_embd], ) @@ -160,6 +164,7 @@ impl NuralNetwork for PatchEmbd { ) .unwrap() ); + // [n * hp/2, wp/2, 2, 2*m] -> [n, hp * wp, m] destruct!( [image_embd] = ctx .call( @@ -167,7 +172,10 @@ impl NuralNetwork for PatchEmbd { "tile", Some(Arg::dict([ ("axis".into(), Arg::int(0)), - ("tiles".into(), Arg::arr([n as u64, hp / 2].map(Arg::from)),) + ( + "tiles".into(), + Arg::arr([n.clone(), hp.clone() / 2].map(Arg::from)), + ) ])), [image_embd], ) @@ -193,7 +201,10 @@ impl NuralNetwork for PatchEmbd { "tile", Some(Arg::dict([ ("axis".into(), Arg::int(2)), - ("tiles".into(), Arg::arr([2, m as u64].map(Arg::from)),) + ( + "tiles".into(), + Arg::arr([Dim::from(2), m.clone()].map(Arg::from)), + ) ])), [image_embd], ) From 38d5b37fdecb5cd953f5ecab4a601f4b781c9f02 Mon Sep 17 00:00:00 2001 From: cearX Date: Tue, 8 Jul 2025 11:37:50 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat(nn):=20=E6=B7=BB=E5=8A=A0=20merge,=20t?= =?UTF-8?q?ile,=20transpose=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 1_nn/src/op/merge.rs | 39 +++++++++++++++++++++++++++++++ 1_nn/src/op/mod.rs | 3 +++ 1_nn/src/op/tile.rs | 50 ++++++++++++++++++++++++++++++++++++++++ 1_nn/src/op/transpose.rs | 38 ++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+) create mode 100644 1_nn/src/op/merge.rs create mode 100644 1_nn/src/op/tile.rs create mode 100644 1_nn/src/op/transpose.rs diff --git a/1_nn/src/op/merge.rs b/1_nn/src/op/merge.rs new file mode 100644 index 0000000..5a9b43c --- /dev/null +++ b/1_nn/src/op/merge.rs @@ -0,0 +1,39 @@ +use super::{OpError, Operator, macros::*}; +use crate::{Arg, Dim, TensorMeta}; + +pub struct Merge; + +impl Operator for Merge { + fn infer(&self, inputs: &[TensorMeta], args: Option<&Arg>) -> Result, OpError> { + let Some(Arg::Dict(args)) = args else { + return Err(OpError::ArgError); + }; + let Some(Arg::Int(start)) = args.get("start") else { + return Err(OpError::ArgError); + }; + let Some(Arg::Int(len)) = args.get("len") else { + return Err(OpError::ArgError); + }; + + let start = *start as usize; + let end = start + *len as usize; + + destruct!([x] = inputs); + + let shape = x.shape(); + + if end > shape.len() { + return Err(OpError::ShapeError); + } + + let merged_dim = shape[start..end] + .iter() + .fold(Dim::from(1), |acc, dim| acc * dim.clone()); + + let mut new_shape = shape[..start].to_vec(); + new_shape.push(merged_dim); + new_shape.extend_from_slice(&shape[end..]); + + Ok(vec![TensorMeta::new(x.dt, new_shape)]) + } +} diff --git a/1_nn/src/op/mod.rs b/1_nn/src/op/mod.rs index 60936ec..8a84565 100644 --- a/1_nn/src/op/mod.rs +++ b/1_nn/src/op/mod.rs @@ -7,10 +7,13 @@ pub mod concat; pub mod conv; pub mod embedding; pub mod linear; +pub mod merge; pub mod mrope; pub mod normalization; pub mod rope; pub mod split; +pub mod tile; +pub mod transpose; /// 计算图层算子,只考虑形状推导 pub trait Operator { diff --git a/1_nn/src/op/tile.rs b/1_nn/src/op/tile.rs new file mode 100644 index 0000000..62204b6 --- /dev/null +++ b/1_nn/src/op/tile.rs @@ -0,0 +1,50 @@ +use super::{OpError, Operator, macros::*}; +use crate::{Arg, Dim, TensorMeta}; + +pub struct Tile; + +impl Operator for Tile { + fn infer(&self, inputs: &[TensorMeta], args: Option<&Arg>) -> Result, OpError> { + let Some(Arg::Dict(args)) = args else { + return Err(OpError::ArgError); + }; + let Some(Arg::Int(axis)) = args.get("axis") else { + return Err(OpError::ArgError); + }; + let Some(Arg::Arr(tile)) = args.get("tile") else { + return Err(OpError::ArgError); + }; + + let axis = *axis as usize; + let tile = tile + .iter() + .map(|p| { + if let Arg::Dim(dim) = p { + Ok(dim.clone()) + } else { + Err(OpError::ArgError) + } + }) + .collect::, _>>()?; + + destruct!([x] = inputs); + + let shape = x.shape(); + + if axis >= shape.len() { + return Err(OpError::ShapeError); + } + + let tile_product = tile.iter().fold(Dim::from(1), |acc, t| acc * t.clone()); + + if tile_product != shape[axis] { + return Err(OpError::ShapeError); + } + + let mut new_shape = shape[..axis].to_vec(); + new_shape.extend_from_slice(tile.as_slice()); + new_shape.extend_from_slice(&shape[axis..]); + + Ok(vec![TensorMeta::new(x.dt, new_shape)]) + } +} diff --git a/1_nn/src/op/transpose.rs b/1_nn/src/op/transpose.rs new file mode 100644 index 0000000..b4b8e8a --- /dev/null +++ b/1_nn/src/op/transpose.rs @@ -0,0 +1,38 @@ +use super::{OpError, Operator, macros::*}; +use crate::{Arg, TensorMeta}; + +pub struct Transpose; + +impl Operator for Transpose { + fn infer(&self, inputs: &[TensorMeta], args: Option<&Arg>) -> Result, OpError> { + let Some(Arg::Dict(args)) = args else { + return Err(OpError::ArgError); + }; + let Some(Arg::Arr(perm)) = args.get("perm") else { + return Err(OpError::ArgError); + }; + + let perm = perm + .iter() + .map(|p| { + if let Arg::Int(perm) = p { + Ok(*perm as usize) + } else { + Err(OpError::ArgError) + } + }) + .collect::, _>>()?; + + destruct!([x] = inputs); + + let shape = x.shape().to_vec(); + + if perm.len() != shape.len() { + return Err(OpError::ShapeError); + } + + let new_shape = perm.iter().map(|&p| shape[p].clone()).collect::>(); + + Ok(vec![TensorMeta::new(x.dt, new_shape)]) + } +} From 41641ae04f9f3832d0a9940f478146644c474ad2 Mon Sep 17 00:00:00 2001 From: cearX Date: Tue, 8 Jul 2025 16:00:22 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat(mem):=20mem=20=E6=B7=BB=E5=8A=A0=20mer?= =?UTF-8?q?ge,=20tile,=20transpose=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 2_mem/src/lib.rs | 3 ++ 2_mem/src/op.rs | 108 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/2_mem/src/lib.rs b/2_mem/src/lib.rs index 83e565b..06eb470 100644 --- a/2_mem/src/lib.rs +++ b/2_mem/src/lib.rs @@ -38,6 +38,9 @@ impl Graph { for (node, topo) in zip(&mut nodes, topo.iter()) { match &*node.value.name { "split" => op::split(node, topo, &mut edges), + "tile" => op::tile(node, topo, &mut edges), + "merge" => op::merge(node, topo, &mut edges), + "transpose" => op::transpose(node, topo, &mut edges), "concat" => op::concat(node, topo, &mut edges), _ => {} } diff --git a/2_mem/src/op.rs b/2_mem/src/op.rs index f40aff7..9dbac6d 100644 --- a/2_mem/src/op.rs +++ b/2_mem/src/op.rs @@ -33,6 +33,114 @@ pub(crate) fn split(node: &mut Node, topo: NodeRef, edges: &mut [Edge]) { } } +pub(crate) fn tile(node: &mut Node, topo: NodeRef, edges: &mut [Edge]) { + let NodeRef { inputs, outputs } = topo; + // tile 应该只有一个输入 + let &[input] = inputs else { unreachable!() }; + let input = edges[input].clone(); + // 提取属性 + let Some(Arg::Dict(arg)) = &node.value.arg else { + unreachable!() + }; + let axis = arg["axis"].to_usize(); + let Some(Arg::Arr(tile)) = arg.get("tile") else { + unreachable!() + }; + let tile = tile + .iter() + .map(|p| { + if let Arg::Dim(dim) = p { + dim.to_usize() + } else { + unreachable!() + } + }) + .collect::>(); + // 计算步长变换 + assert_eq!(outputs.len(), 1); // tile 应该只有一个输出 + for output in outputs { + let output = &mut edges[output]; + // 暂时不支持 output 是外部的,因为外部 output 需要添加 rearrange kernel + assert!(matches!(&**output.get(), Info::Internal(_))); + // 用 tile_be 实现,并替换原来的边 + *output = input + .clone() + .transform(|layout| layout.tile_be(axis, &tile)); + } + // 算子擦除 + node.value = Operator { + name: "empty".to_string(), + arg: None, + } +} + +pub(crate) fn merge(node: &mut Node, topo: NodeRef, edges: &mut [Edge]) { + let NodeRef { inputs, outputs } = topo; + // merge 应该只有一个输入 + let &[input] = inputs else { unreachable!() }; + let input = edges[input].clone(); + // 提取属性 + let Some(Arg::Dict(arg)) = &node.value.arg else { + unreachable!() + }; + let start = arg["start"].to_usize(); + let len = arg["len"].to_usize(); + // 计算步长变换 + assert_eq!(outputs.len(), 1); // merge 应该只有一个输出 + for output in outputs { + let output = &mut edges[output]; + // 暂时不支持 output 是外部的,因为外部 output 需要添加 rearrange kernel + assert!(matches!(&**output.get(), Info::Internal(_))); + // 用 merge_be 实现,并替换原来的边 + *output = input + .clone() + .transform(|layout| layout.merge_be(start, len).unwrap()); + } + // 算子擦除 + node.value = Operator { + name: "empty".to_string(), + arg: None, + } +} + +pub(crate) fn transpose(node: &mut Node, topo: NodeRef, edges: &mut [Edge]) { + let NodeRef { inputs, outputs } = topo; + // transpose 应该只有一个输入 + let &[input] = inputs else { unreachable!() }; + let input = edges[input].clone(); + // 提取属性 + let Some(Arg::Dict(arg)) = &node.value.arg else { + unreachable!() + }; + let Some(Arg::Arr(perm)) = arg.get("perm") else { + unreachable!() + }; + let perm = perm + .iter() + .map(|p| { + if let Arg::Int(perm) = p { + *perm as usize + } else { + unreachable!() + } + }) + .collect::>(); + // 计算步长变换 + assert_eq!(outputs.len(), 1); // transpose 应该只有一个输出 + for output in outputs { + let output = &mut edges[output]; + // 暂时不支持 output 是外部的,因为外部 output 需要添加 rearrange kernel + assert!(matches!(&**output.get(), Info::Internal(_))); + // 用 transpose 实现,并替换原来的边 + *output = input.clone().transform(|layout| layout.transpose(&perm)); + } + // 算子擦除 + node.value = Operator { + name: "empty".to_string(), + arg: None, + } +} + pub(crate) fn concat(node: &mut Node, topo: NodeRef, edges: &mut [Edge]) { let NodeRef { inputs, outputs } = topo; // concat 应该只有一个输出