diff --git a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
index 5487a2a7..7a69c776 100644
--- a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
+++ b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
@@ -235,6 +235,120 @@ namespace deepx::tf
             return 0;
         }
     };
+
+    template <typename Author>
+    class Mul : public TF
+    {
+    public:
+        Mul(vector<Param> args, vector<Param> returns)
+        {
+            this->name = "mul";
+            this->author = Author::name();
+            this->args = args;
+            this->returns = returns;
+        }
+        string math_formula() const override
+        {
+            return "T3=T1*T2";
+        }
+        shared_ptr<TF> clone() const override
+        {
+            return make_shared<Mul<Author>>(*this);
+        }
+        int run(shared_ptr<MemBase> mem, string &error) override
+        {
+            Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
+            Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
+            Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
+            if (a_type != b_type || a_type != c_type)
+            {
+                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " != " + precision_str(c_type);
+                return 1;
+            }
+            switch (a_type)
+            {
+            case Precision::Float64:
+                tensorfunc::mul<Author>(*mem->gettensor<double>(this->args[0].textvalue), *mem->gettensor<double>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
+                break;
+            case Precision::Float32:
+                tensorfunc::mul<Author>(*mem->gettensor<float>(this->args[0].textvalue), *mem->gettensor<float>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
+                break;
+            case Precision::Int64:
+                tensorfunc::mul<Author>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int32:
+                tensorfunc::mul<Author>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int16:
+                tensorfunc::mul<Author>(*mem->gettensor<int16_t>(this->args[0].textvalue), *mem->gettensor<int16_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int8:
+                tensorfunc::mul<Author>(*mem->gettensor<int8_t>(this->args[0].textvalue), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
+                break;
+            default:
+                error = "Unsupported dtype: " + precision_str(a_type);
+                return 1;
+            }
+            return 0;
+        }
+    };
+
+    template <typename Author>
+    class MulScalar : public TF
+    {
+    public:
+        MulScalar(vector<Param> args, vector<Param> returns)
+        {
+            this->name = "mulscalar";
+            this->author = Author::name();
+            this->args = args;
+            this->returns = returns;
+        }
+        string math_formula() const override
+        {
+            return "T3=T1*scalar";
+        }
+        shared_ptr<TF> clone() const override
+        {
+            return make_shared<MulScalar<Author>>(*this);
+        }
+        int run(shared_ptr<MemBase> mem, string &error) override
+        {
+            Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
+            Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
+            if (a_type != c_type)
+            {
+                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type);
+                return 1;
+            }
+            switch (a_type)
+            {
+            case Precision::Float64:
+                tensorfunc::mulscalar<Author>(*mem->gettensor<double>(this->args[0].textvalue), this->getvar<double>(1, mem), *mem->gettensor<double>(this->returns[0].textvalue));
+                break;
+            case Precision::Float32:
+                tensorfunc::mulscalar<Author>(*mem->gettensor<float>(this->args[0].textvalue), this->getvar<float>(1, mem), *mem->gettensor<float>(this->returns[0].textvalue));
+                break;
+            case Precision::Int64:
+                tensorfunc::mulscalar<Author>(*mem->gettensor<int64_t>(this->args[0].textvalue), this->getvar<int64_t>(1, mem), *mem->gettensor<int64_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int32:
+                tensorfunc::mulscalar<Author>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int16:
+                tensorfunc::mulscalar<Author>(*mem->gettensor<int16_t>(this->args[0].textvalue), this->getvar<int16_t>(1, mem), *mem->gettensor<int16_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int8:
+                tensorfunc::mulscalar<Author>(*mem->gettensor<int8_t>(this->args[0].textvalue), this->getvar<int8_t>(1, mem), *mem->gettensor<int8_t>(this->returns[0].textvalue));
+                break;
+            default:
+                error = "Unsupported dtype: " + precision_str(a_type);
+                return 1;
+            }
+            return 0;
+        }
+    };
 
 }
 #endif
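
Note: with this executor-side `Mul`/`MulScalar` pair, the dtype no longer rides on the op name; it is checked per tensor argument inside `run()` and dispatched to the typed kernel. A minimal sketch of the front-end IR construction this pairs with (tensor names `T1`/`T2`/`T3` are hypothetical; `Param` and the new `DeepxIR` signature come from the deepxir.py diff below):

```python
from deepx.nn.deepxir import DeepxIR, Param

# Each tensor Param carries its own category and precision tag.
a = Param("T1", category="tensor", precision="float32")
b = Param("T2", category="tensor", precision="float32")
out = Param("T3", category="tensor", precision="float32")

# author selects the executor-side implementation (Author::name() in the C++ above).
ir = DeepxIR("mul", [a, b], [out], author="miaobyte")

print(a)  # tensor<float32> T1
```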
diff --git a/front/py/deepx/nn/deepxir.py b/front/py/deepx/nn/deepxir.py
index 6dce38ea..988afdc2 100644
--- a/front/py/deepx/nn/deepxir.py
+++ b/front/py/deepx/nn/deepxir.py
@@ -1,14 +1,36 @@
-from typing import Tuple, List, Optional
+from typing import Tuple, List, Optional, Union
 import time
 from datetime import datetime  # add the datetime module
 
+class Param:
+    def __init__(self, value:Optional[Union[str,int,float,list,tuple]], category:str=None, precision:str=None):
+        if isinstance(value, str):
+            self._textvalue = value
+        elif isinstance(value, int) or isinstance(value, float):
+            self._textvalue = str(value)
+        elif isinstance(value, list) or isinstance(value, tuple):
+            self._textvalue = '[' + ' '.join(str(v) for v in value) + ']'
+        else:
+            raise ValueError(f"Invalid value type: {type(value)}")
+
+        self._category = category
+        self._precision = precision
+
+    def __str__(self):
+        if self._category is not None:
+            if self._precision is not None:
+                return f"{self._category}<{self._precision}> {self._textvalue}"
+            else:
+                return f"{self._category} {self._textvalue}"
+        else:
+            return self._textvalue
+
 class DeepxIR:
     def __init__(self,
             name:str,
-            dtype:str,
-            args: List[str],
-            returns: List[str],
-            author:str):
+            args: List[Param],
+            returns: List[Param],
+            author:str=''):
         """
         Initialize the op node
         Args:
@@ -17,8 +39,7 @@ def __init__(self,
             author: author name of the tensorfunc, e.g. "miaobyte"
         """
-        self._name = name
-        self._dtype = dtype
+        self._name = name
         self._args = args
         self._returns = returns
         self._author = author
@@ -28,10 +49,7 @@ def __init__(self,
 
     def __str__(self):
         # function name part
-        if self._dtype == None or self._dtype == '':
-            parts = [self._name]
-        else:
-            parts = [f"{self._name}@{self._dtype}"]
+        parts = [self._name]
 
         # format the input args - parenthesized and comma-separated
         args_parts = []
diff --git a/front/py/deepx/nn/functional/new.py b/front/py/deepx/nn/functional/new.py
index 879eda7d..1cc14dff 100644
--- a/front/py/deepx/nn/functional/new.py
+++ b/front/py/deepx/nn/functional/new.py
@@ -1,6 +1,6 @@
 from deepx.tensor import Tensor
 from deepx.autograd.graph import Graph
-from deepx.nn.deepxir import DeepxIR
+from deepx.nn.deepxir import DeepxIR,Param
 from deepx.scheduler import send
 
 def newtensor(t:Tensor,name:str=None):
@@ -8,7 +8,7 @@ def newtensor(t:Tensor,name:str=None):
     t._graph = graph
     t._node=graph.add_tensor(name,t=t)
     if t.graph.eager:
-        ir2=DeepxIR("newtensor", t.dtype, t.shape, [t._node.name])
+        ir2=DeepxIR("newtensor",[Param(t.shape)], [Param(t._node.name,category='tensor',precision=t.dtype)])
         send(ir2)
 def copytensor(t:Tensor,out:Tensor):
     graph = Graph.get_default()
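
The list/tuple branch of `Param` is what `newtensor` relies on: a shape is rendered as space-separated values in brackets, while tensor params get a `category<precision>` prefix. The string forms follow directly from `Param.__str__` above:

```python
print(Param((3, 4, 5)))                                     # [3 4 5]
print(Param("t0", category="tensor", precision="float32"))  # tensor<float32> t0
print(Param(0.5))                                           # 0.5
```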
ir=DeepxIR("print",'', [t.node.name,format], []) +def printtensor(t:Tensor,format='',author='miaobyte'): + ir=DeepxIR("print",[t.node.name,format], [],author) send(ir) return ''