41 changes: 26 additions & 15 deletions 1_nn/src/nn/patch_embd.rs
@@ -1,6 +1,6 @@
use super::{Context, Distribution, NNError, NuralNetwork, TPTensor, Tensor, macros::destruct};
use crate::macros::dims;
use arg::Arg;
use arg::{Arg, Dim};
use tensor::digit_layout::DigitLayout;

#[derive(Clone)]
@@ -41,28 +41,25 @@ impl<T> NuralNetwork<T> for PatchEmbd<T> {
) -> Result<(Context<T>, Vec<Tensor<T>>), NNError> {
destruct!([x] = inputs);

dims!([n, c, h, w] = x);
let [n, c, height, width] = [n, c, h, w].map(|v| v.to_usize());
dims!([n, _c, height, width] = x);
let Self {
dt,
shape,
patch_embd,
patch_embd1,
} = self;
let [m, ck, hk, wk] = shape;
assert_eq!(n, 1);
assert_eq!(c, ck);
assert_eq!(hk, wk);
let [m, ck, hk, wk] = shape.map(Dim::from);
assert!(hk.eq(&wk));
let w = ctx.load_external(
"patch_embd",
dt,
[m.into(), ck.into(), hk.into(), wk.into()],
[m.clone(), ck.clone(), hk.clone(), wk.clone()],
patch_embd,
);
let w1 = ctx.load_external(
"patch_embd1",
dt,
[m.into(), ck.into(), hk.into(), wk.into()],
[m.clone(), ck.clone(), hk.clone(), wk.clone()],
patch_embd1,
);
let tensors = ctx
@@ -78,8 +75,8 @@ impl<T> NuralNetwork<T> for PatchEmbd<T> {

// reshape

let hp = (height / hk) as u64; // h patches
let wp = (width / wk) as u64; // w patches
let hp = height.clone() / hk.clone(); // h patches
let wp = width.clone() / wk.clone(); // w patches
// [n, m, hp, wp] -> [n, hp, wp, m]
destruct!(
[image_embd] = ctx
@@ -94,14 +91,18 @@ impl<T> NuralNetwork<T> for PatchEmbd<T> {
)
.unwrap()
);
// [n, hp, wp, m] -> [n * hp/2, 2, wp/2, 2*m]
destruct!(
[image_embd] = ctx
.call(
"",
"tile",
Some(Arg::dict([
("axis".into(), Arg::int(1)),
("tiles".into(), Arg::arr([hp / 2, 2].map(Arg::from)),)
(
"tiles".into(),
Arg::arr([hp.clone() / 2, Dim::from(2)].map(Arg::from)),
)
])),
[image_embd],
)
@@ -127,7 +128,10 @@ impl<T> NuralNetwork<T> for PatchEmbd<T> {
"tile",
Some(Arg::dict([
("axis".into(), Arg::int(2)),
("tiles".into(), Arg::arr([wp / 2, 2].map(Arg::from)),)
(
"tiles".into(),
Arg::arr([wp / 2, Dim::from(2)].map(Arg::from)),
)
])),
[image_embd],
)
@@ -160,14 +164,18 @@ impl<T> NuralNetwork<T> for PatchEmbd<T> {
)
.unwrap()
);
// [n * hp/2, wp/2, 2, 2*m] -> [n, hp * wp, m]
destruct!(
[image_embd] = ctx
.call(
"",
"tile",
Some(Arg::dict([
("axis".into(), Arg::int(0)),
("tiles".into(), Arg::arr([n as u64, hp / 2].map(Arg::from)),)
(
"tiles".into(),
Arg::arr([n.clone(), hp.clone() / 2].map(Arg::from)),
)
])),
[image_embd],
)
@@ -193,7 +201,10 @@ impl<T> NuralNetwork<T> for PatchEmbd<T> {
"tile",
Some(Arg::dict([
("axis".into(), Arg::int(2)),
("tiles".into(), Arg::arr([2, m as u64].map(Arg::from)),)
(
"tiles".into(),
Arg::arr([Dim::from(2), m.clone()].map(Arg::from)),
)
])),
[image_embd],
)
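The substance of this file's change is swapping concrete `usize` arithmetic for symbolic `Dim` arithmetic, so patch counts like `hp` and `wp` stay symbolic until shapes are bound. A minimal sketch of the idea, using a hypothetical stand-in type rather than the `arg` crate's actual `Dim`:

use std::ops::Div;

// Hypothetical stand-in for a symbolic dimension (illustrative only; the
// real `arg::Dim` may differ): either a known constant or a named symbol
// carrying an accumulated constant divisor.
#[derive(Clone, Debug, PartialEq)]
enum SymDim {
    Const(usize),
    Sym { name: &'static str, div: usize },
}

impl From<usize> for SymDim {
    fn from(v: usize) -> Self {
        SymDim::Const(v)
    }
}

impl Div<usize> for SymDim {
    type Output = SymDim;
    fn div(self, rhs: usize) -> SymDim {
        match self {
            SymDim::Const(v) => SymDim::Const(v / rhs),
            SymDim::Sym { name, div } => SymDim::Sym { name, div: div * rhs },
        }
    }
}

// `hp / 2` keeps `hp` symbolic instead of forcing it to a concrete value:
// assert_eq!(
//     SymDim::Sym { name: "hp", div: 1 } / 2,
//     SymDim::Sym { name: "hp", div: 2 },
// );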
39 changes: 39 additions & 0 deletions 1_nn/src/op/merge.rs
@@ -0,0 +1,39 @@
use super::{OpError, Operator, macros::*};
use crate::{Arg, Dim, TensorMeta};

pub struct Merge;

impl Operator for Merge {
fn infer(&self, inputs: &[TensorMeta], args: Option<&Arg>) -> Result<Vec<TensorMeta>, OpError> {
let Some(Arg::Dict(args)) = args else {
return Err(OpError::ArgError);
};
let Some(Arg::Int(start)) = args.get("start") else {
return Err(OpError::ArgError);
};
let Some(Arg::Int(len)) = args.get("len") else {
return Err(OpError::ArgError);
};

let start = *start as usize;
let end = start + *len as usize;

destruct!([x] = inputs);

let shape = x.shape();

if end > shape.len() {
return Err(OpError::ShapeError);
}

let merged_dim = shape[start..end]
.iter()
.fold(Dim::from(1), |acc, dim| acc * dim.clone());

let mut new_shape = shape[..start].to_vec();
new_shape.push(merged_dim);
new_shape.extend_from_slice(&shape[end..]);

Ok(vec![TensorMeta::new(x.dt, new_shape)])
}
}
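The shape rule above, restated as a standalone sketch over plain `usize` dims (illustrative only; the operator itself works on symbolic `Dim`s): dims `start..start + len` are folded into their product.

fn merge_shape(shape: &[usize], start: usize, len: usize) -> Option<Vec<usize>> {
    let end = start.checked_add(len)?;
    if end > shape.len() {
        return None;
    }
    let mut out = shape[..start].to_vec();
    // fold the merged range into a single dim
    out.push(shape[start..end].iter().product());
    out.extend_from_slice(&shape[end..]);
    Some(out)
}

// e.g. merge_shape(&[2, 3, 4, 5], 1, 2) == Some(vec![2, 12, 5])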
3 changes: 3 additions & 0 deletions 1_nn/src/op/mod.rs
@@ -7,10 +7,13 @@ pub mod concat;
pub mod conv;
pub mod embedding;
pub mod linear;
pub mod merge;
pub mod mrope;
pub mod normalization;
pub mod rope;
pub mod split;
pub mod tile;
pub mod transpose;

/// Graph-level operator: only shape inference is considered.
pub trait Operator {
50 changes: 50 additions & 0 deletions 1_nn/src/op/tile.rs
@@ -0,0 +1,50 @@
use super::{OpError, Operator, macros::*};
use crate::{Arg, Dim, TensorMeta};

pub struct Tile;

impl Operator for Tile {
fn infer(&self, inputs: &[TensorMeta], args: Option<&Arg>) -> Result<Vec<TensorMeta>, OpError> {
let Some(Arg::Dict(args)) = args else {
return Err(OpError::ArgError);
};
let Some(Arg::Int(axis)) = args.get("axis") else {
return Err(OpError::ArgError);
};
let Some(Arg::Arr(tile)) = args.get("tiles") else {
return Err(OpError::ArgError);
};

let axis = *axis as usize;
let tile = tile
.iter()
.map(|p| {
if let Arg::Dim(dim) = p {
Ok(dim.clone())
} else {
Err(OpError::ArgError)
}
})
.collect::<Result<Vec<_>, _>>()?;

destruct!([x] = inputs);

let shape = x.shape();

if axis >= shape.len() {
return Err(OpError::ShapeError);
}

let tile_product = tile.iter().fold(Dim::from(1), |acc, t| acc * t.clone());

if tile_product != shape[axis] {
return Err(OpError::ShapeError);
}

let mut new_shape = shape[..axis].to_vec();
new_shape.extend_from_slice(tile.as_slice());
new_shape.extend_from_slice(&shape[axis + 1..]);

Ok(vec![TensorMeta::new(x.dt, new_shape)])
}
}
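Tile is the inverse of merge: one dim is replaced by factors whose product must equal its extent. The same rule as a standalone sketch over plain `usize` dims (illustrative only):

fn tile_shape(shape: &[usize], axis: usize, tile: &[usize]) -> Option<Vec<usize>> {
    if axis >= shape.len() || tile.iter().product::<usize>() != shape[axis] {
        return None;
    }
    let mut out = shape[..axis].to_vec();
    out.extend_from_slice(tile);
    // the original dim at `axis` is consumed by the tiles
    out.extend_from_slice(&shape[axis + 1..]);
    Some(out)
}

// e.g. tile_shape(&[4, 6, 8], 1, &[2, 3]) == Some(vec![4, 2, 3, 8])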
38 changes: 38 additions & 0 deletions 1_nn/src/op/transpose.rs
@@ -0,0 +1,38 @@
use super::{OpError, Operator, macros::*};
use crate::{Arg, TensorMeta};

pub struct Transpose;

impl Operator for Transpose {
fn infer(&self, inputs: &[TensorMeta], args: Option<&Arg>) -> Result<Vec<TensorMeta>, OpError> {
let Some(Arg::Dict(args)) = args else {
return Err(OpError::ArgError);
};
let Some(Arg::Arr(perm)) = args.get("perm") else {
return Err(OpError::ArgError);
};

let perm = perm
.iter()
.map(|p| {
if let Arg::Int(perm) = p {
Ok(*perm as usize)
} else {
Err(OpError::ArgError)
}
})
.collect::<Result<Vec<_>, _>>()?;

destruct!([x] = inputs);

let shape = x.shape().to_vec();

if perm.len() != shape.len() {
return Err(OpError::ShapeError);
}

let new_shape = perm.iter().map(|&p| shape[p].clone()).collect::<Vec<_>>();

Ok(vec![TensorMeta::new(x.dt, new_shape)])
}
}
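And the transpose rule as the same kind of sketch: output dim `i` takes the extent of input dim `perm[i]`, failing if `perm` has the wrong length or an out-of-range index.

fn transpose_shape(shape: &[usize], perm: &[usize]) -> Option<Vec<usize>> {
    if perm.len() != shape.len() {
        return None;
    }
    perm.iter().map(|&p| shape.get(p).copied()).collect()
}

// e.g. transpose_shape(&[1, 16, 16, 32], &[0, 3, 1, 2]) == Some(vec![1, 32, 16, 16])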
3 changes: 3 additions & 0 deletions 2_mem/src/lib.rs
@@ -38,6 +38,9 @@ impl<T> Graph<T> {
for (node, topo) in zip(&mut nodes, topo.iter()) {
match &*node.value.name {
"split" => op::split(node, topo, &mut edges),
"tile" => op::tile(node, topo, &mut edges),
"merge" => op::merge(node, topo, &mut edges),
"transpose" => op::transpose(node, topo, &mut edges),
"concat" => op::concat(node, topo, &mut edges),
_ => {}
}
108 changes: 108 additions & 0 deletions 2_mem/src/op.rs
@@ -33,6 +33,114 @@ pub(crate) fn split<T>(node: &mut Node, topo: NodeRef, edges: &mut [Edge<T>]) {
}
}

pub(crate) fn tile<T>(node: &mut Node, topo: NodeRef, edges: &mut [Edge<T>]) {
let NodeRef { inputs, outputs } = topo;
// tile should have exactly one input
let &[input] = inputs else { unreachable!() };
let input = edges[input].clone();
// extract the attributes
let Some(Arg::Dict(arg)) = &node.value.arg else {
unreachable!()
};
let axis = arg["axis"].to_usize();
let Some(Arg::Arr(tile)) = arg.get("tiles") else {
unreachable!()
};
let tile = tile
.iter()
.map(|p| {
if let Arg::Dim(dim) = p {
dim.to_usize()
} else {
unreachable!()
}
})
.collect::<Vec<_>>();
// compute the stride transformation
assert_eq!(outputs.len(), 1); // tile should have exactly one output
for output in outputs {
let output = &mut edges[output];
// external outputs are not supported yet, since they would require inserting a rearrange kernel
assert!(matches!(&**output.get(), Info::Internal(_)));
// implement via tile_be and replace the original edge
*output = input
.clone()
.transform(|layout| layout.tile_be(axis, &tile));
}
// erase the operator
node.value = Operator {
name: "empty".to_string(),
arg: None,
}
}

pub(crate) fn merge<T>(node: &mut Node, topo: NodeRef, edges: &mut [Edge<T>]) {
let NodeRef { inputs, outputs } = topo;
// merge should have exactly one input
let &[input] = inputs else { unreachable!() };
let input = edges[input].clone();
// extract the attributes
let Some(Arg::Dict(arg)) = &node.value.arg else {
unreachable!()
};
let start = arg["start"].to_usize();
let len = arg["len"].to_usize();
// compute the stride transformation
assert_eq!(outputs.len(), 1); // merge should have exactly one output
for output in outputs {
let output = &mut edges[output];
// external outputs are not supported yet, since they would require inserting a rearrange kernel
assert!(matches!(&**output.get(), Info::Internal(_)));
// implement via merge_be and replace the original edge
*output = input
.clone()
.transform(|layout| layout.merge_be(start, len).unwrap());
}
// erase the operator
node.value = Operator {
name: "empty".to_string(),
arg: None,
}
}

pub(crate) fn transpose<T>(node: &mut Node, topo: NodeRef, edges: &mut [Edge<T>]) {
let NodeRef { inputs, outputs } = topo;
// transpose should have exactly one input
let &[input] = inputs else { unreachable!() };
let input = edges[input].clone();
// extract the attributes
let Some(Arg::Dict(arg)) = &node.value.arg else {
unreachable!()
};
let Some(Arg::Arr(perm)) = arg.get("perm") else {
unreachable!()
};
let perm = perm
.iter()
.map(|p| {
if let Arg::Int(perm) = p {
*perm as usize
} else {
unreachable!()
}
})
.collect::<Vec<_>>();
// compute the stride transformation
assert_eq!(outputs.len(), 1); // transpose should have exactly one output
for output in outputs {
let output = &mut edges[output];
// external outputs are not supported yet, since they would require inserting a rearrange kernel
assert!(matches!(&**output.get(), Info::Internal(_)));
// implement via transpose and replace the original edge
*output = input.clone().transform(|layout| layout.transpose(&perm));
}
// erase the operator
node.value = Operator {
name: "empty".to_string(),
arg: None,
}
}
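All three rewrites above share one idea: tile, merge, and transpose never move data, so the memory pass realizes each as a stride transform on the input edge and erases the node to `empty`. As an illustrative-only sketch (not the layout crate's actual API), transpose on a strided layout amounts to permuting (extent, stride) pairs:

// Illustrative only: why transpose needs no copy on a strided layout.
fn transpose_strided(
    shape: &[usize],
    strides: &[isize],
    perm: &[usize],
) -> (Vec<usize>, Vec<isize>) {
    (
        perm.iter().map(|&p| shape[p]).collect(),
        perm.iter().map(|&p| strides[p]).collect(),
    )
}

// A row-major [2, 3] view has strides [3, 1]; transposing with perm = [1, 0]
// yields shape [3, 2] with strides [1, 3] over the same buffer.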

pub(crate) fn concat<T>(node: &mut Node, topo: NodeRef, edges: &mut [Edge<T>]) {
let NodeRef { inputs, outputs } = topo;
// concat should have exactly one output