From 06a43185d5a09bfb11e181df147bbcfa656d7d76 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Fri, 18 Jul 2025 15:19:53 -0400 Subject: [PATCH] feat: add missing functions for scalar bools --- TensorLib/Bytes.lean | 10 ++++++++++ TensorLib/Tensor.lean | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/TensorLib/Bytes.lean b/TensorLib/Bytes.lean index 9d57cf6..b86039f 100644 --- a/TensorLib/Bytes.lean +++ b/TensorLib/Bytes.lean @@ -33,6 +33,16 @@ instance : ToLEByteArray ByteArray where instance : ToBEByteArray ByteArray where toBEByteArray arr := arr +instance : ToLEByteArray Bool where + toLEByteArray + | false => ByteArray.mk #[0] + | true => ByteArray.mk #[1] + +instance : ToBEByteArray Bool where + toBEByteArray + | false => ByteArray.mk #[0] + | true => ByteArray.mk #[1] + -- We cast between UIntX and ByteArray without going through BitVec, which is represented -- as a Nat at runtime. We'll have tests for these in our ByteArray extension module where we -- have conversions back and forth between LE ByteArrays and UIntX. diff --git a/TensorLib/Tensor.lean b/TensorLib/Tensor.lean index ac90384..b58f78f 100644 --- a/TensorLib/Tensor.lean +++ b/TensorLib/Tensor.lean @@ -325,6 +325,10 @@ def arrayScalarInt (dtype : Dtype) (n : Int) : Err Tensor := do def arrayScalarInt! (dtype : Dtype) (n : Int) : Tensor := get! $ arrayScalarInt dtype n +def arrayScalarBool (b : Bool) : Err Tensor := arrayScalar Dtype.bool (toLEByteArray b) + +def arrayScalarBool! (b : Bool) : Tensor := get! $ arrayScalarBool b + def arrayScalarFloat32 (f : Float32) : Err Tensor := arrayScalar Dtype.float32 (toLEByteArray f) def arrayScalarFloat32! (f : Float32) : Tensor := get! $ arrayScalarFloat32 f @@ -406,6 +410,20 @@ def ofIntList (dtype : Dtype) (ns : List Int) : Err Tensor := do def ofIntList! (dtype : Dtype) (ns : List Int) : Tensor := get! $ ofIntList dtype ns +def ofFloat32List (ns : List Float32) : Err Tensor := do + let dtype := TensorLib.Dtype.float32 + let size := dtype.itemsize + let arr := Tensor.zeros dtype (Shape.mk [ns.length]) + let mut data := arr.data + let mut posn := 0 + for n in ns do + let v <- dtype.byteArrayOfFloat32 n + data := v.copySlice 0 data posn size + posn := posn + size + .ok { arr with data := data } + +def ofFloat32List! (ns : List Float32) : Tensor := get! $ ofFloat32List ns + def getDimIndex (arr : Tensor) (index : DimIndex) : Err ByteArray := if arr.shape.ndim != index.length then .error "getDimIndex: index mismatch" else let offset := Shape.dimIndexToOffset arr.unitStrides index