Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepxctl/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.idea
deepxctl
83 changes: 83 additions & 0 deletions deepxctl/cmd/tensor/print.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package tensor

import (
"flag"
"fmt"
"os"

coretensor "github.com/array2d/deepx/deepxctl/tensor"
)

func PrintCmd() {
printCmd := flag.NewFlagSet("print", flag.ExitOnError)
tensorPath := os.Args[0]
if tensorPath == "" {
fmt.Println("请指定文件路径")
printCmd.Usage()
return
}
var err error
var shape coretensor.Shape
shape, err = coretensor.LoadShape(tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
switch shape.Dtype {
case "bool":
var t coretensor.Tensor[bool]
t, err = coretensor.LoadTensor[bool](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
case "int8":
var t coretensor.Tensor[int8]
t, err = coretensor.LoadTensor[int8](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
case "int16":
var t coretensor.Tensor[int16]
t, err = coretensor.LoadTensor[int16](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
case "int32":
var t coretensor.Tensor[int32]
t, err = coretensor.LoadTensor[int32](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
case "int64":
var t coretensor.Tensor[int64]
t, err = coretensor.LoadTensor[int64](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
case "float16":
// var t coretensor.Tensor[float16]
// t, err = coretensor.LoadTensor[float16](tensorPath)
// if err != nil {
// fmt.Println("读取文件失败:", err)
// }
// t.Print()
case "float32":
var t coretensor.Tensor[float32]
t, err = coretensor.LoadTensor[float32](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
case "float64":
var t coretensor.Tensor[float64]
t, err = coretensor.LoadTensor[float64](tensorPath)
if err != nil {
fmt.Println("读取文件失败:", err)
}
t.Print()
}
}
36 changes: 36 additions & 0 deletions deepxctl/cmd/tensor/tensor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package tensor

import (
"fmt"
"os"
)

func PrintUsage() {
fmt.Println("使用方法:")
fmt.Println(" tensor print <文件路径>")
fmt.Println(" tensor help")
}

func Execute() {
if len(os.Args) < 1 {
PrintUsage()
os.Exit(1)
}

subCmd := "help"
if len(os.Args) > 0 {
subCmd = os.Args[0]
}

switch subCmd {
case "print":
os.Args = os.Args[1:]
PrintCmd()
case "help":
PrintUsage()
default:
fmt.Printf("未知的张量命令: %s\n", subCmd)
PrintUsage()
os.Exit(1)
}
}
5 changes: 5 additions & 0 deletions deepxctl/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module github.com/array2d/deepx/deepxctl

go 1.23.2

require gopkg.in/yaml.v2 v2.4.0 // indirect
3 changes: 3 additions & 0 deletions deepxctl/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
64 changes: 64 additions & 0 deletions deepxctl/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package main

import (
"flag"
"fmt"
"os"
"path/filepath"

"github.com/array2d/deepx/deepxctl/cmd/tensor"
)

var version = "0.1.0"

func printUsage() {
execName := filepath.Base(os.Args[0])
fmt.Printf("用法: %s [命令] [参数]\n\n", execName)
fmt.Println("可用命令:")
fmt.Println(" tensor 张量操作相关命令")
fmt.Println(" version 显示版本信息")
fmt.Println(" help 显示帮助信息")
fmt.Println("\n使用 '%s help [命令]' 获取命令的详细信息", execName)
}

func main() {
flag.Usage = printUsage

if len(os.Args) < 2 {
printUsage()
os.Exit(1)
}

// 获取子命令
cmd := os.Args[1]

// 根据子命令执行相应操作
switch cmd {
case "tensor":
// 移除子命令,让子命令处理剩余的参数
os.Args = os.Args[2:]
tensor.Execute()

case "version":
fmt.Printf("deepxctl 版本 %s\n", version)

case "help":
if len(os.Args) > 2 {
helpCmd := os.Args[2]
switch helpCmd {
case "tensor":
tensor.PrintUsage()
default:
fmt.Printf("未知命令: %s\n", helpCmd)
printUsage()
}
} else {
printUsage()
}

default:
fmt.Printf("未知命令: %s\n", cmd)
printUsage()
os.Exit(1)
}
}
28 changes: 28 additions & 0 deletions deepxctl/tensor/fp16.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package tensor

import (
"encoding/binary"
"math"
)

func Byte2ToFloat16(value []byte) float32 {
bits := binary.BigEndian.Uint16(value)
// 这里需要实现float16到float32的转换
// 简化实现,实际项目中需要更完整的实现
sign := float32(1)
if bits&0x8000 != 0 {
sign = -1
}
exp := int((bits & 0x7C00) >> 10)
frac := float32(bits&0x03FF) / 1024.0

if exp == 0 {
return sign * frac * float32(1.0/16384.0) // 非规格化数
} else if exp == 31 {
if frac == 0 {
return sign * float32(math.Inf(1)) // 无穷大
}
return float32(math.NaN()) // NaN
}
return sign * float32(math.Pow(2, float64(exp-15))) * (1.0 + frac) // 规格化数
}
47 changes: 47 additions & 0 deletions deepxctl/tensor/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package tensor

import (
"encoding/binary"
"os"

"gopkg.in/yaml.v2"
)

func LoadShape(filePath string) (shape Shape, err error) {
var shapeData []byte
shapeData, err = os.ReadFile(filePath + ".shape")
if err != nil {
return
}

err = yaml.Unmarshal(shapeData, &shape)
if err != nil {
return
}
return
}
func LoadTensor[T Number](filePath string) (tensor Tensor[T], err error) {

_, err = os.ReadFile(filePath + ".shape")
if err != nil {
return
}
var shape Shape
shape, err = LoadShape(filePath)
if err != nil {
return
}
file, err := os.Open(filePath + ".data")
if err != nil {
return
}
defer file.Close()
data := make([]T, shape.Size)

err = binary.Read(file, binary.LittleEndian, data)
if err != nil {
return
}
tensor = Tensor[T]{Data: data, Shape: shape}
return
}
121 changes: 121 additions & 0 deletions deepxctl/tensor/print.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package tensor

import "fmt"

func (t *Tensor[T]) Range(dimCount int, f func(indices []int)) {
Shape := t.Shape
if dimCount > len(Shape.Shape) {
panic("dimCount exceeds the number of dimensions in the Tensor.")
}

totalSize := 1

// 计算总的循环次数
for i := 0; i < dimCount; i++ {
totalSize *= Shape.At(i)
}
indices := make([]int, dimCount) // 初始化索引向量
// 遍历所有可能的索引组合
for idx := 0; idx < totalSize; idx++ {
// 反算出 indices 数组
idx_ := idx
for dim := dimCount - 1; dim >= 0; dim-- {
indices[dim] = idx_ % Shape.At(dim) // 计算当前维度的索引
idx_ /= Shape.At(dim) // 更新 idx
}
f(indices) // 调用传入的函数
}
}

func AutoFormat(dtype string) string {
switch dtype {
case "bool":
return "%v"
case "int8":
return "%d"
case "int16":
return "%d"
case "int32":
return "%d"
case "int64":
return "%d"
case "float16":
return "%f"
case "float32":
return "%f"
case "float64":
return "%f"
default:
return "%v"
}
}

// Print 打印Tensor的值
func (t *Tensor[T]) Print(format_ ...string) {
Shape := t.Shape
format := AutoFormat(t.Dtype)
if len(format_) > 0 {
format = format_[0]
}
fmt.Print("shape:[")
for i := 0; i < Shape.Dim; i++ {
fmt.Print(Shape.At(i))
if i < Shape.Dim-1 {
fmt.Print(", ")
}
}
fmt.Println("]")
if Shape.Dim == 1 {
fmt.Print("[")
for i := 0; i < Shape.At(0); i++ {
if i > 0 {
fmt.Print(" ")
}
fmt.Printf(format, t.Get(i))
}
fmt.Println("]")
} else if Shape.Dim == 2 {
fmt.Println("[")
for i := 0; i < Shape.At(0); i++ {
fmt.Print(" [")
for j := 0; j < Shape.At(1); j++ {
if j > 0 {
fmt.Print(" ")
}
fmt.Printf(format, t.Get(i, j))
}

fmt.Print("]")
if i < Shape.At(0)-1 {
fmt.Print(",")
}
fmt.Println()
}
fmt.Println("]")
} else {
t.Range(Shape.Dim-2, func(indices []int) {
fmt.Print(indices)
m, n := Shape.At(Shape.Dim-2), Shape.At(Shape.Dim-1)
fmt.Print([]int{m, n})
fmt.Println("=")

fmt.Println("[")
for i := 0; i < m; i++ {
fmt.Print(" [")
for j := 0; j < n; j++ {
if j > 0 {
fmt.Print(" ")
}
fmt.Printf(format, t.Get(append(indices, i, j)...))
}

fmt.Print("]")
if i < m-1 {
fmt.Print(",")
}
fmt.Println()
}
fmt.Println("]")
})
}
}
Loading