Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,120 @@ namespace deepx::tf
return 0;
}
};

template <typename Author>
// Elementwise tensor multiplication: returns[0] = args[0] * args[1].
// All three tensors must share the same dtype; dispatch is on that dtype.
class Mul : public TF
{
public:
    Mul(vector<Param> args, vector<Param> returns)
    {
        this->name = "mul";
        this->author = Author::name();
        this->args = args;
        this->returns = returns;
    }
    string math_formula() const override
    {
        return "T3=T1*T2";
    }
    shared_ptr<TF> clone() const override
    {
        return make_shared<Mul<Author>>(*this);
    }
    // Executes the elementwise multiply on tensors resolved by name from `mem`.
    // Returns 0 on success; 1 with `error` set on dtype mismatch or unsupported dtype.
    int run(shared_ptr<MemBase> mem, string &error) override
    {
        Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
        Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
        Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
        // Both inputs and the output must agree on dtype before dispatching.
        if (a_type != b_type || a_type != c_type)
        {
            error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " != " + precision_str(c_type);
            return 1;
        }
        switch (a_type)
        {
        case Precision::Float64:
            tensorfunc::mul<Author, double>(*mem->gettensor<double>(this->args[0].textvalue), *mem->gettensor<double>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
            break;
        case Precision::Float32:
            tensorfunc::mul<Author, float>(*mem->gettensor<float>(this->args[0].textvalue), *mem->gettensor<float>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
            break;
        case Precision::Int64:
            // BUGFIX: was dispatched as int32_t, truncating 64-bit tensor data.
            tensorfunc::mul<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
            break;
        case Precision::Int32:
            tensorfunc::mul<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
            break;
        case Precision::Int16:
            tensorfunc::mul<Author, int16_t>(*mem->gettensor<int16_t>(this->args[0].textvalue), *mem->gettensor<int16_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
            break;
        case Precision::Int8:
            tensorfunc::mul<Author, int8_t>(*mem->gettensor<int8_t>(this->args[0].textvalue), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
            break;
        default:
            error = "Unsupported dtype: " + precision_str(a_type);
            return 1;
        }
        return 0;
    }
};

template <typename Author>
class MulScalar : public TF
{
public:
MulScalar(vector<Param> args, vector<Param> returns)
{
this->name = "mulscalar";
this->author = Author::name();
this->args = args;
this->returns = returns;
}
string math_formula() const override
{
return "T3=T1*scalar";
}
shared_ptr<TF> clone() const override
{
return make_shared<MulScalar<Author>>(*this);
}
int run(shared_ptr<MemBase> mem, string &error) override
{
Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
if (a_type != c_type)
{
error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type);
return 1;
}
switch (a_type)
{
case Precision::Float64:
tensorfunc::mulscalar<Author, double>(*mem->gettensor<double>(this->args[0].textvalue), this->getvar<double>(1, mem), *mem->gettensor<double>(this->returns[0].textvalue));
break;
case Precision::Float32:
tensorfunc::mulscalar<Author, float>(*mem->gettensor<float>(this->args[0].textvalue), this->getvar<float>(1, mem), *mem->gettensor<float>(this->returns[0].textvalue));
break;
case Precision::Int64:
tensorfunc::mulscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
break;
case Precision::Int32:
tensorfunc::mulscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
break;
case Precision::Int16:
tensorfunc::mulscalar<Author, int16_t>(*mem->gettensor<int16_t>(this->args[0].textvalue), this->getvar<int16_t>(1, mem), *mem->gettensor<int16_t>(this->returns[0].textvalue));
break;
case Precision::Int8:
tensorfunc::mulscalar<Author, int8_t>(*mem->gettensor<int8_t>(this->args[0].textvalue), this->getvar<int8_t>(1, mem), *mem->gettensor<int8_t>(this->returns[0].textvalue));
break;
default:
error = "Unsupported dtype: " + precision_str(a_type);
return 1;
}
return 0;
}
};

}

#endif
40 changes: 29 additions & 11 deletions front/py/deepx/nn/deepxir.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
from typing import Tuple, List, Optional
from typing import Tuple, List, Optional,Union
import time
from datetime import datetime # 添加datetime模块

class Param:
    """A single IR parameter: a text value plus an optional category/precision tag.

    The text form is: str as-is, numbers via str(), and list/tuple rendered as
    '[v1 v2 ...]' (space-separated, bracketed). str(param) prepends
    'category<precision> ' or 'category ' when those tags are set.
    """

    def __init__(self, value:Optional[Union[str,int,float,list,tuple]], category:str=None,precision:str=None):
        # Normalize the value to its text representation up front.
        if isinstance(value, str):
            text = value
        elif isinstance(value, (int, float)):
            text = str(value)
        elif isinstance(value, (list, tuple)):
            text = '[' + ' '.join(map(str, value)) + ']'
        else:
            raise ValueError(f"Invalid value type: {type(value)}")
        self._textvalue = text
        self._category = category
        self._precision = precision

    def __str__(self):
        # Untagged params render as the bare text value.
        if self._category is None:
            return self._textvalue
        prefix = self._category if self._precision is None else f"{self._category}<{self._precision}>"
        return f"{prefix} {self._textvalue}"

class DeepxIR:
def __init__(self,
name:str,
dtype:str,
args: List[str],
returns: List[str],
author:str):
args: List[Param],
returns: List[Param],
author:str=''):
"""
初始化操作节点
Args:
Expand All @@ -17,8 +39,7 @@ def __init__(self,
author: tensorfunc的作者名称,如"miaobyte"
"""

self._name = name
self._dtype = dtype
self._name = name
self._args = args
self._returns = returns
self._author = author
Expand All @@ -28,10 +49,7 @@ def __init__(self,

def __str__(self):
# 函数名部分
if self._dtype == None or self._dtype == '':
parts = [self._name]
else:
parts = [f"{self._name}@{self._dtype}"]
parts = [self._name]

# 处理输入参数部分 - 使用括号和逗号分隔
args_parts = []
Expand Down
4 changes: 2 additions & 2 deletions front/py/deepx/nn/functional/new.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from deepx.tensor import Tensor
from deepx.autograd.graph import Graph
from deepx.nn.deepxir import DeepxIR
from deepx.nn.deepxir import DeepxIR,Param
from deepx.scheduler import send

def newtensor(t:Tensor,name:str=None):
graph = Graph.get_default()
t._graph = graph
t._node=graph.add_tensor(name,t=t)
if t.graph.eager:
ir2=DeepxIR("newtensor", t.dtype, t.shape, [t._node.name])
ir2=DeepxIR("newtensor",[Param(t.shape)], [Param(t._node.name,category='tensor',precision=t.dtype)])
send(ir2)
def copytensor(t:Tensor,out:Tensor):
graph = Graph.get_default()
Expand Down
4 changes: 2 additions & 2 deletions front/py/deepx/nn/functional/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from deepx.scheduler import send

OpNode.register("print")
def printtensor(t:Tensor,format=''):
ir=DeepxIR("print",'', [t.node.name,format], [])
def printtensor(t:Tensor,format='',author='miaobyte'):
ir=DeepxIR("print",[t.node.name,format], [],author)
send(ir)
return ''