Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions TensorLib/Bytes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ instance : ToLEByteArray ByteArray where
instance : ToBEByteArray ByteArray where
toBEByteArray arr := arr

instance : ToLEByteArray Bool where
toLEByteArray
| false => ByteArray.mk #[0]
| true => ByteArray.mk #[1]

instance : ToBEByteArray Bool where
toBEByteArray
| false => ByteArray.mk #[0]
| true => ByteArray.mk #[1]

-- We cast between UIntX and ByteArray without going through BitVec, which is represented
-- as a Nat at runtime. We'll have tests for these in our ByteArray extension module where we
-- have conversions back and forth between LE ByteArrays and UIntX.
Expand Down
18 changes: 18 additions & 0 deletions TensorLib/Tensor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def arrayScalarInt (dtype : Dtype) (n : Int) : Err Tensor := do

def arrayScalarInt! (dtype : Dtype) (n : Int) : Tensor := get! $ arrayScalarInt dtype n

def arrayScalarBool (b : Bool) : Err Tensor := arrayScalar Dtype.bool (toLEByteArray b)

def arrayScalarBool! (b : Bool) : Tensor := get! $ arrayScalarBool b

def arrayScalarFloat32 (f : Float32) : Err Tensor := arrayScalar Dtype.float32 (toLEByteArray f)

def arrayScalarFloat32! (f : Float32) : Tensor := get! $ arrayScalarFloat32 f
Expand Down Expand Up @@ -406,6 +410,20 @@ def ofIntList (dtype : Dtype) (ns : List Int) : Err Tensor := do

def ofIntList! (dtype : Dtype) (ns : List Int) : Tensor := get! $ ofIntList dtype ns

def ofFloat32List (ns : List Float32) : Err Tensor := do
let dtype := TensorLib.Dtype.float32
let size := dtype.itemsize
let arr := Tensor.zeros dtype (Shape.mk [ns.length])
let mut data := arr.data
let mut posn := 0
for n in ns do
let v <- dtype.byteArrayOfFloat32 n
data := v.copySlice 0 data posn size
posn := posn + size
.ok { arr with data := data }

def ofFloat32List! (ns : List Float32) : Tensor := get! $ ofFloat32List ns

def getDimIndex (arr : Tensor) (index : DimIndex) : Err ByteArray :=
if arr.shape.ndim != index.length then .error "getDimIndex: index mismatch" else
let offset := Shape.dimIndexToOffset arr.unitStrides index
Expand Down