From c98bd82e6758b4250f97f65634c8f3651c984ec4 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Mon, 7 Jul 2025 12:49:41 -0400 Subject: [PATCH 01/15] [feat] Add HLR IR and compilation from HLO to HLR (#194) * Add Sherloc as a dependency * Add the definition of the HLR AST * Add functionality to compile HLO to HLR * Add functionality to export an HLR program as a graph * Add functionality to export an HLR program as a Python program with equivalent semantics * Add a CLI command to compile an HLO program to HLR --- .gitignore | 1 + KLR/HLR.lean | 1 + KLR/HLR/AST.lean | 281 +++++++++++++++++++++++++ KLR/HLR/Basic.lean | 4 + KLR/HLR/Compile.lean | 466 +++++++++++++++++++++++++++++++++++++++++ KLR/HLR/Dot.lean | 122 +++++++++++ KLR/HLR/Py.lean | 155 ++++++++++++++ KLR/lake-manifest.json | 38 +++- KLR/lakefile.lean | 3 + Main.lean | 39 ++++ lake-manifest.json | 10 + lakefile.lean | 3 + 12 files changed, 1119 insertions(+), 4 deletions(-) create mode 100644 KLR/HLR.lean create mode 100644 KLR/HLR/AST.lean create mode 100644 KLR/HLR/Basic.lean create mode 100644 KLR/HLR/Compile.lean create mode 100644 KLR/HLR/Dot.lean create mode 100644 KLR/HLR/Py.lean diff --git a/.gitignore b/.gitignore index 022ef304..b8107b32 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ build/ dist/ klr.bin /.wheel/ +env diff --git a/KLR/HLR.lean b/KLR/HLR.lean new file mode 100644 index 00000000..4ce2b3ec --- /dev/null +++ b/KLR/HLR.lean @@ -0,0 +1 @@ +import KLR.HLR.Basic diff --git a/KLR/HLR/AST.lean b/KLR/HLR/AST.lean new file mode 100644 index 00000000..73a1608a --- /dev/null +++ b/KLR/HLR/AST.lean @@ -0,0 +1,281 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.Core.Operators +import KLR.Util +import SHerLOC +import TensorLib.Dtype +import TensorLib.Shape +import TensorLib.Slice +import TensorLib.Tensor + +open TensorLib (Shape Dtype) + +/- +The definition of the High-Level Representation (HLR) IR. The goal of this IR is to +be a uniform representation for graphs of tensor operations, which we can use as a +common compilation target for different frontends (e.g. StableHLO, PyTorch FX, etc.). +A HLR program consists of a list of functions, each with a name, and input and output tensors. +The function body is in SSA, with each operation producing a single output tensor. +-/ +namespace KLR.HLR + +structure TensorTy where + shape : Shape + dtype : Dtype +deriving Inhabited, Repr, Nonempty + +abbrev Var := String + +-- scalar-scalar binary operators +inductive BinaryOp where + | add + | sub + | mul + | div + | and + | max + | cmp +deriving Inhabited, Repr + +-- scalar unary operators +inductive UnaryOp where + | exp + | sqrt + | neg + | convert (dtype : TensorLib.Dtype) +deriving Inhabited, Repr + +/- +Operators in the HLR (High-Level Representation) of KLR. + +Note: some HLO operations have "load-bearing" output shapes, meaning the +output shape is a vital part of the operation's semantics (e.g. `reshape`). +For these operators, we store the output shape in the `Operator`, even +though this means that when considering an `Operator` as part of a `Statement`, +the output shape information exists in two redundant places: in the `Statement` +and in the `Operator`. +-/ +inductive Operator where + -- An argument to the function, identified by its index. + | arg (index : Nat) + + -- apply a binary operator element-wise to two tensors + | binaryOp (op : BinaryOp) (a b : Var) + -- apply a unary operator element-wise to a tensor + | unaryOp (op : UnaryOp) (a : Var) + -- apply a reduction operation to a tensor, reducing it along the specified dimensions + | reductionOp (op : BinaryOp) (a b : Var) (dim : List Nat) + + -- perform a batch matrix multiplication on two tensors. + -- Specifically, computes the einsum bij,bkj->bik + | batchMatmul (a b : Var) + -- create a tensor with a range of values within the given limits and with the specified stride + | arange (start : Nat) (stop : Nat) (step : Nat) (shape : Shape) + -- concatenate a list of tensors along the specified dimension + | concat (tensors : List Var) (dim : Nat) + -- select elements from two tensors based on a condition tensor + | select (cond a b : Var) + -- create a tensor filled with a specific value, with the given shape + | full (value : Float32) (shape : Shape) + -- transpose a tensor with the provided permutation of dimensions + | transpose (a : Var) (dims : List Nat) + -- unused + | split_with_sizes (a : Var) (sizes : List Nat) -- ?? + -- reshape a tensor to the specified shape + | reshape (a : Var) (shape : Shape) + -- broadcast a tensor to the specified shape + -- TODO: broadcasting is very complicated and we haven't figured it out yet, + -- so this instruction just passes through the semantics of HLO's broadcasting + | broadcast (a : Var) (shape : Shape) (broadcastDims : List Nat) + -- create a constant tensor with the given values and shape + | const (values : TensorLib.Tensor) (shape : Shape) (dtype : TensorLib.Dtype) + -- gather elements from a tensor using the provided indices and offset dimensions + -- TODO: gather is complicated and not used except for in llama, so for now + -- we just pass through the semantics of HLO's gather + | gather (input indices : Var) (offsetDims collapsedSliceDims startIndexMap : List Nat) (indexVectorDim : Nat) + -- slice a tensor along specified dimensions, with start, limit, and stride + | slice (a : Var) (slice : List TensorLib.Slice) + -- call another function, passing input values and receiving outputs + | call (callee : String) (inputValues : List Var) +deriving Inhabited, Repr + +/- +A statement in HLR (High Level Representation). +In SSA form, so each variable is assigned exactly once. +-/ +inductive Statement where + -- A comment in the code, for making the dumped IR readable + | comment (msg : String) + /- + Assign the result of `op` to `dest` , with resulting shape `shape` + + Note: We store the shape directly, even though it is inferrable based on the, + operator, to avoid having to recompute it with fallible operations later. + -/ + | assign (dest : Var) (op : Operator) (shape : TensorTy) + -- Return variables `vars` from the function + | ret (vars : List Var) +deriving Inhabited, Repr + +/- +An HLR function. Note that arguments are referred to by index, so +we only store the argument shapes, not names. +-/ +structure Function where + name : String + inputs : List TensorTy + outputs : List TensorTy + statements : List Statement +deriving Inhabited, Repr, Nonempty + +-- An HLR program +structure Program where + functions : List Function +deriving Inhabited, Repr, Nonempty + +-- Returns the list of variables that this operator immediately depends on. +def dependencies : Operator → List Var + | .arg _ => [] + | .binaryOp _ a b => [a, b] + | .unaryOp _ a => [a] + | .reductionOp _ a b _ => [a, b] + | .batchMatmul a b => [a, b] + | .arange .. => [] + | .concat tensors _ => tensors + | .select cond a b => [cond, a, b] + | .full .. => [] + | .transpose a _ => [a] + | .split_with_sizes a _ => [a] + | .reshape a _ => [a] + | .broadcast a .. => [a] + | .const .. => [] + | .gather a i .. => [a, i] + | .slice a .. => [a] + | .call _ inputs => inputs + +-- Returns the list of all variables defined in this function. +def vars (f : Function) : List Var := + f.statements.filterMap (fun + | .assign dest .. => .some dest + | _ => .none) + +-- Finds the operator that assigns to a variable in the function. +def findVar (f : Function) (v : Var) : Option Operator := + f.statements.findSome? (fun + | .assign dest op _ => if dest == v then .some op else .none + | _ => .none) + +-- TODO: move these toString instances to the TensorLib repo +instance : ToString TensorLib.Slice where + toString s := + let {start, stop, step, ..} := s + let start := start.map toString |>.getD "" + let stop := stop.map toString |>.getD "" + let step := step.map toString |>.getD "" + s!"{start}:{stop}:{step}" + +instance : ToString TensorLib.Shape where + toString s := + s.val.map toString |> "x".intercalate |> fun x => s!"[{x}]" + +instance : ToString TensorLib.Dtype where + toString + | .bool => "bool" + | .int8 => "i8" + | .int16 => "i16" + | .int32 => "i32" + | .int64 => "i64" + | .uint8 => "u8" + | .uint16 => "u16" + | .uint32 => "u32" + | .uint64 => "u64" + | .float32 => "f32" + | .float64 => "f64" + +instance : ToString BinaryOp where + toString + | .add => "add" + | .sub => "sub" + | .mul => "mul" + | .div => "div" + | .and => "and" + | .max => "max" + | .cmp => "cmp" + +instance : ToString UnaryOp where + toString + | .exp => "exp" + | .sqrt => "sqrt" + | .neg => "neg" + | .convert dtype => s!"convert_{dtype}" + +instance : ToString TensorTy where + toString (t : TensorTy) : String := + s!"{t.shape}{t.dtype}" + +instance : ToString Operator where + toString + | .arg n => s!"arg({n})" + | .binaryOp binOp a b => s!"{binOp}({a}, {b})" + | .unaryOp unOp a => s!"{unOp}({a})" + | .reductionOp redOp a b dim => s!"reduce-{redOp}({a}, {b}, dim={dim})" + | .batchMatmul a b => s!"matmul({a}, {b})" + | .arange start stop step shape => s!"arange({start}, {stop}, {step}, shape={shape})" + | .concat tensors dim => s!"concat({", ".intercalate tensors}, dim={dim})" + | .select cond a b => s!"select({cond}, {a}, {b})" + | .full v shape => s!"full({repr v}, shape={shape})" + | .transpose a dims => s!"transpose({a}, dims={dims})" + | .split_with_sizes a sizes => s!"split_with_sizes({a}, sizes={sizes})" + | .reshape a shape => s!"reshape({a}, shape={shape})" + | .broadcast a shape dims => s!"broadcast({a}, shape={shape}, dims={dims})" + | .const t shape dtype => s!"const({repr t}, shape={shape}, dtype={dtype})" + | .gather a indices offsetDims collapsedSliceDims startIndexMap indexVectorDim + => s!" gather({a}, indices={indices}, offsetDims={offsetDims}, collapsedSliceDims={collapsedSliceDims}, startIndexMap={startIndexMap}, indexVectorDim={indexVectorDim})" + | .slice a slices => s!"slice({a}, {slices})" + | .call callee inputValues => + let inputsStr := inputValues.map toString |> ", ".intercalate + s!"call({callee}, inputs=[{inputsStr}])" + +instance : ToString Statement where + toString + | .comment msg => s!"# {msg}" + | .assign dest op shape => s!"{dest} : {toString shape} = {op}" + | .ret name => s!"return {name}" + +instance : ToString Function where + toString f := + let inputsStr := f.inputs.map toString |> ", ".intercalate + let outputsStr := f.outputs.map toString |> ", ".intercalate + let statementsStr := f.statements.map toString |> "\n".intercalate + s!"def {f.name}({inputsStr}) -> ({outputsStr}):\n{statementsStr}" + +instance : ToString Program where + toString p := + let functionsStr := p.functions.map toString |> "\n".intercalate + s!"# Program\n" ++ functionsStr + +-- Human readable name for the operator. +def opName : Operator → String + | .arg _ => s!"arg" + | .binaryOp binOp .. => s!"{binOp}" + | .unaryOp unOp .. => s!"{unOp}" + | .reductionOp redOp .. => s!"{redOp}" + | .batchMatmul .. => s!"batchMatmul" + | .arange .. => s!"arange" + | .concat .. => s!"concat" + | .select .. => s!"select" + | .full .. => s!"full" + | .transpose .. => s!"transpose" + | .split_with_sizes .. => s!"split_with_sizes" + | .reshape .. => s!"reshape" + | .broadcast .. => s!"broadcast" + | .const .. => s!"const" + | .gather .. => s!"gather" + | .slice .. => s!"slice" + | .call callee .. => s!"call {callee}" + +end KLR.HLR diff --git a/KLR/HLR/Basic.lean b/KLR/HLR/Basic.lean new file mode 100644 index 00000000..609cfa1b --- /dev/null +++ b/KLR/HLR/Basic.lean @@ -0,0 +1,4 @@ +import KLR.HLR.AST +import KLR.HLR.Compile +import KLR.HLR.Dot +import KLR.HLR.Py diff --git a/KLR/HLR/Compile.lean b/KLR/HLR/Compile.lean new file mode 100644 index 00000000..5ee6965d --- /dev/null +++ b/KLR/HLR/Compile.lean @@ -0,0 +1,466 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.Core.Operators +import KLR.HLR.AST +import KLR.Util +import SHerLOC +import TensorLib.Shape +import TensorLib.Slice +import TensorLib.Tensor + +open TensorLib (Dtype Shape Tensor) + +-- This module compiles a StableHLO program into an HLR program. +namespace KLR.HLR.Compile + +-- Context for the compilation process, to be stored in a state monad. +structure Ctx where + -- the program being compiled + program : Program + -- the log of messages generated during compilation (for debugging) + log : List String +deriving Inhabited, Repr + +-- Compilation requires tracking state and also potentially returning errors. +abbrev Compile T := StM Ctx T + +-- Emit a message to the compilation log. +def log (msg : String) : Compile Unit := + modify (fun ctx => { ctx with log := ctx.log ++ [msg]}) + +-- Add a function to the program being compiled. +def addFunction (func : Function) : Compile Unit := do + modify (fun ctx => + { ctx with program := { ctx.program with functions := ctx.program.functions ++ [func] } }) + +-- Permute `l` according to the indices in `permutation`. +def permute {T : Type} (l : List T) (permutation : List Nat) : Option (List T) := + permutation.mapM fun dim => l[dim]? + +/- +Parses a StableHLO floatliteral to a Float32. + +TODO: Probably has all sorts of rounding errors, but dealing with floats in Lean +is so agonizing, and the semantics of this number storage system are so confusing, +that I'm satisfied with this as a first pass for now. + +Examples: +The FloatLiteral + intPart: {.minus, 4} + fracPart: {.plus, 785980} + sciPart: {.minus, 1} +Represents the number 4.4785980e-1, which is + (4 + 0.4785980) * 10 ^ -1 + +Alternatively, if we have + intPart: {.minus, 3} + fracPart: {.plus, 597620} + sciPart: {.minus, 3} +Then it represents the number -3.597620e-3, which is + (-3 + 0.597620) * 10 ^ -3 +-/ +def parseFloat (c : StableHLO.Parsing.FloatLiteral) : Float32 := + match c with + | .decimal { + integerPart := ⟨ intSign, intDecimal ⟩, + fractionalPart := ⟨ _, fracDecimal ⟩, + scientificPart := ⟨ sciSign, sciDecimal ⟩ + } => + let exponent := match sciSign with + | .plus => Int.ofNat sciDecimal + | .minus => -1 * Int.ofNat sciDecimal + let sign := match intSign with + | .plus => 1 + | .minus => (-1 : Float32) + let mantissa := String.toNat! s!"{intDecimal}{fracDecimal}" + let fracDigits := ToString.toString fracDecimal |>.length |> Int.ofNat + let exponent := exponent - fracDigits + let (exponentSign, exponent) := if exponent >= 0 then + (false, exponent.natAbs) + else + (true, exponent.natAbs) + sign * OfScientific.ofScientific mantissa exponentSign exponent + | .hexaDecimal _ => panic! "Hexadecimal float literals are not supported yet." + +#guard parseFloat (.decimal { integerPart := ⟨ .plus, 4 ⟩, fractionalPart := ⟨ .plus, 785980 ⟩, scientificPart := ⟨ .plus, 3 ⟩}) == 4785.980 +#guard parseFloat (.decimal { integerPart := ⟨ .minus, 4 ⟩, fractionalPart := ⟨ .plus, 785980 ⟩, scientificPart := ⟨ .minus, 1 ⟩}) == -0.4785980 +#guard parseFloat (.decimal { integerPart := ⟨ .minus, 3 ⟩, fractionalPart := ⟨ .plus, 597620 ⟩, scientificPart := ⟨ .minus, 3 ⟩}) == -0.003597620 + +-- Parse a StableHLO element literal to a Float32. +def parseFloatFromElementLiteral (c : StableHLO.Parsing.ElementLiteral) : Compile Float32 := + match c with + |(.floatLiteral f) => pure (parseFloat f) + | _ => throw s!"Expected a float literal, but got {repr c}." + +-- Convert a list of Float32 values to a TensorLib tensor. +def ofFloatList (ns : List Float32) : Compile Tensor := do + let dtype := TensorLib.Dtype.float32 + let size := dtype.itemsize + let arr := Tensor.zeros dtype (Shape.mk [ns.length]) + let mut data := arr.data + let mut posn := 0 + for n in ns do + let v <- dtype.byteArrayOfFloat32 n + data := v.copySlice 0 data posn size + posn := posn + size + .ok { arr with data := data } + +-- Convert a list of tensors to a single tensor by concatenating them along a new first dimension. +def ofTensorList (ns : List Tensor) : Compile Tensor := + match ns with + | f :: r => do + if ! (r.all fun t => t.shape == f.shape) then + throw s!"All tensors in the list must have the same shape, but got {repr ns}." + else + let newShape := Shape.mk (ns.length :: f.shape.val) + let dtype := f.dtype + let size := ns.foldl (fun acc t => acc + t.size) 0 + let arr := Tensor.zeros dtype newShape + let mut data := arr.data + let mut posn := 0 + for t in ns do + data := t.data.copySlice 0 data posn (t.data.size * size) + posn := posn + (t.data.size * size) + .ok { arr with data := data } + | [] => pure (TensorLib.Tensor.empty TensorLib.Dtype.float32 (Shape.mk [])) + +-- Parse a StableHLO dense literal to an HLR tensor. +def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor + | .denseDimension values => do + ofTensorList (← values.mapM parseTensorLiteral) + | .denseElements elems => do + (← elems.mapM parseFloatFromElementLiteral) |> ofFloatList + +-- Convert a StableHLO tensor type to an HLR TensorTy. +def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do + let shape ← t.shape.mapM (fun + | .known d => pure d + | .unknown => throw "Can't support tensors with dynamic shape") + let dtype ← match t.tensorElementTypeGen with + | .classic (.floatType .f32) => pure TensorLib.Dtype.float32 + | _ => throw s!"Unsupported tensor element type: {repr t.tensorElementTypeGen}" + pure (.mk (.mk shape) dtype) + +-- Parse an HLR TensorTy at index `n` from the list of types. +def parseTensorTypeFromValueTypes (l : List StableHLO.Parsing.ValueType) (n : Nat): Compile TensorTy := + match l[n]? with + | .some (.tensorType t) => parseTensorType t + | .some t => throw s!"Element {n} of type list must have tensor type, but got {repr t}." + | _ => throw s!"Type list must have at least {n + 1} values, but got only {l.length}." + +-- Parse an HLR TensorTy from a list of types, expecting the list to have exactly one element. +def parseSingleTensorTypeFromValueTypes : List StableHLO.Parsing.ValueType → Compile TensorTy + | [.tensorType t] => parseTensorType t + | t => throw s!"Expected type list to have a single tensor type, but got {repr t}." + +-- Parse an array from a StableHLO literal. +def parseArray (c : StableHLO.Parsing.Literal) : Compile (List Nat) := + match c with + | .array (.array64 arr) => pure (arr.map fun ⟨ _sign, n ⟩ => n) + | .array (.array1 _) => throw "array1 unimplemented." + | _ => throw "Expected an array of integers." + +/- +Parse a Nat from a StableHLO float literal. +We need this because integers are often represented as floats in StableHLO. +-/ +def parseNatFromFloat (c : StableHLO.Parsing.Literal) : Compile Nat := + match c with + | .element (.floatLiteral (.decimal {integerPart, fractionalPart, scientificPart})) => + match (fractionalPart.decimal == 0, scientificPart.decimal == 0, integerPart.sign) with + | (true, true, .plus) => pure integerPart.decimal + | (false, _, _) | (_, false, _) => + throw s!"Expected a non-negative integer, but got a float literal with fractional or scientific part: {repr c}." + | (_, _, .minus) => + throw s!"Expected a non-negative integer, but got a float literal with negative sign: {repr c}." + | .element (.floatLiteral l) => throw s!"Got unsupported float literal {repr l}." + | l => throw s!"Expected a float literal but got {repr l}." + +-- Find an attribute by name in a list of attributes +def lookupAttribute (attrs : List StableHLO.Parsing.Attribute) (name : String) : Compile StableHLO.Parsing.Constant := + match attrs.find? (fun ⟨ id, _ ⟩ => id == name) with + | some ⟨ _, attr ⟩ => pure attr + | none => throw s!"Attribute '{name}' not found." + +-- Find an attribute by name in a list of attributes, returning only the associated literal, not its type +def lookupAttributeValue (attrs : List StableHLO.Parsing.Attribute) (name : String) : Compile StableHLO.Parsing.Literal := + lookupAttribute attrs name |>.map (fun ⟨ lit, _ ⟩ => lit) + +-- Get the value of a field in a StableHLO record, expecting it to be a list of integers. +def lookupNatsInFields (fields : List StableHLO.Parsing.StableHLORecordField) (name : String) : Compile (List Nat) := + match fields.find? (fun ⟨ n, _ ⟩ => n == name) with + | some (.mk _ (.many ns)) => pure ns + | some v => throw s!"Field '{name}' must be a list of integers, but got {repr v}." + | none => pure [] + +-- Get the value of a field in a StableHLO record, expecting it to be a single integer. +def lookupNatInFields (fields : List StableHLO.Parsing.StableHLORecordField) (name : String) : Compile Nat := + match fields.find? (fun ⟨ n, _ ⟩ => n == name) with + | some (.mk _ (.one n)) => pure n + | some v => throw s!"Field '{name}' must be a single integer, but got {repr v}." + | none => throw s!"Field '{name}' not found in record list {repr fields}." + +-- extract the arguments to the `dotGeneral` operation from a record in the list of attributes +def extractDotDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile (List Nat × List Nat × List Nat × List Nat) := do + let dotAttr ← lookupAttributeValue attrs "dot_dimension_numbers" + match dotAttr with + | .stableHLORecord fields => + let lhs_batching_dims ← lookupNatsInFields fields "lhs_batching_dimensions" + let lhs_contracting_dims ← lookupNatsInFields fields "lhs_contracting_dimensions" + let rhs_batching_dims ← lookupNatsInFields fields "rhs_batching_dimensions" + let rhs_contracting_dims ← lookupNatsInFields fields "rhs_contracting_dimensions" + pure (lhs_batching_dims, lhs_contracting_dims, rhs_batching_dims, rhs_contracting_dims) + | _ => throw "Attribute 'dot_dimension_numbers' must be a stableHLORecord." + +-- extract the arguments to the `gather` operation from a record in the list of attributes +def extractDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile (List Nat × List Nat × List Nat × Nat) := do + let attr ← lookupAttributeValue attrs "dimension_numbers" + match attr with + | .stableHLORecord fields => + let offset_dims ← lookupNatsInFields fields "offset_dims" + let collapsed_slice_dims ← lookupNatsInFields fields "collapsed_slice_dims" + let start_index_map ← lookupNatsInFields fields "start_index_map" + let index_vector_dim ← lookupNatInFields fields "index_vector_dim" + pure (offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim) + | _ => throw "Attribute 'dimension_numbers' must be a stableHLORecord." + +/- +The StableHLO `reduce` operation always calls an arbitrary reduction function. +However, in HLR we only support a few specific reduction operations (mostly +arithmetic and logical binary operators). Since many StableHLO programs only +use these basic reduction operations, we can recognize when the StableHLO function +called by a `reduce` operation is one of these basic operations, and convert it +to the corresponding HLR BinaryOp. +If this process is unsuccessful, it means that the input `reduce` function is more +complicated and can't be supported by the current HLR design. +-/ +def reduceFunctionToReduceOp (f : StableHLO.Parsing.InputFunc) : Compile (BinaryOp) := do + match f with + | .mk _ [.stablehlo .maximum .., .return ..] => pure .max + | .mk _ [.stablehlo .add .., .return ..] => pure .add + | .mk _ [.stablehlo .and .., .return ..] => pure .and + | .mk _ [.stablehlo op .., .return ..] => throw s!"Unimplemented reduction function {repr op}." + | op => + throw ("Unable to recognize `reduce` function as simple binary operator. Compiling" ++ + "this program likely requires adding support for arbitrary function calling in `reduce`" + ++ s!"Function: {repr op}") + +/- +Compile a StableHLO operation into a list of HLR statements. + +Note: this function annotates each statement with the type of its output, +but this type is merely passed through from the HLO program, not computed anew. +This means it's possible that if there's a mistake in the shape calculation +in the HLO program, the HLR statements will also have incorrect shapes. +Eventually, we'll want a function that can shape-check an HLR program. +-/ +def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) + | .stablehlo opCode inputValues inputFunctions inputAttributes outputs signature => do + -- Reuse the variable names and shapes from the StableHLO program + let output ← match outputs with + | [output] => pure output + | _ => throw "Operator signature must have a single output." + let outputTy ← parseSingleTensorTypeFromValueTypes signature.range + -- helper function to emit HLR for element-wise unary ops + let makeUnOp := fun (op : UnaryOp) => do + log s!"Compiling unary op {op}" + let a := inputValues[0]! + pure [.assign output (.unaryOp op a) outputTy] + -- helper function to emit HLR for element-wise binary ops + let makeBinOp := fun (op : BinaryOp) => do + log s!"Compiling binary op {op}" + let a := inputValues[0]! + let b := inputValues[1]! + pure [.assign output (.binaryOp op a b) outputTy] + match opCode with + -- element-wise unary operators + | .sqrt => makeUnOp .sqrt + | .negate => makeUnOp .neg + | .exponential => makeUnOp .exp + | .convert => makeUnOp (UnaryOp.convert outputTy.dtype) + -- element-wise binary operators + | .compare => makeBinOp .cmp + | .multiply => makeBinOp .mul + | .add => makeBinOp .add + | .and => makeBinOp .and + | .maximum => makeBinOp .max + | .subtract => makeBinOp .sub + | .divide => makeBinOp .div + -- tensor nullary operators + | .constant => do + log "Compiling constant operation" + let valueAttr ← lookupAttributeValue inputAttributes "value" + match valueAttr with + | (.tensor (.denseElements [(.floatLiteral f)])) => + pure [.assign output (.full (parseFloat f) outputTy.shape) outputTy] + | (.tensor lit) => + pure [.assign output (.const (← parseTensorLiteral lit) outputTy.shape outputTy.dtype) outputTy] + | _ => throw "Constant operation requires a 'value' attribute with tensor literal." + -- tensor unary operators + | .reshape => do + log "reshape" + let input := inputValues[0]! + pure [.assign output (.reshape input outputTy.shape) outputTy] + | .gather => + log "Compiling gather operation" + let (offsetDims, collapsedSliceDims, startIndexMap, indexVectorDim) ← + extractDimensionNumbers inputAttributes + let input := inputValues[0]! + let indices := inputValues[1]! + pure [.assign output (.gather input indices offsetDims collapsedSliceDims startIndexMap indexVectorDim) outputTy] + | .slice => + log "Compiling slice operation" + let input := inputValues[0]! + let start ← (← lookupAttributeValue inputAttributes "start_indices") |> parseArray + let limit ← (← lookupAttributeValue inputAttributes "limit_indices") |> parseArray + let stride ← (← lookupAttributeValue inputAttributes "strides") |> parseArray + let slices ← start + |> List.length + |> List.range + |> List.mapM (fun i => + TensorLib.Slice.make + (.some $ Int.ofNat start[i]!) + (.some $ Int.ofNat limit[i]!) + (.some $ Int.ofNat stride[i]!)) + pure [.assign output (.slice input slices) outputTy] + | .reduce => + log "Compiling reduce operation" + let op ← reduceFunctionToReduceOp inputFunctions[0]! + let dims ← (← lookupAttributeValue inputAttributes "dimensions") |> parseArray + pure [.assign output (.reductionOp op inputValues[0]! inputValues[1]! dims) outputTy] -- TODO: init value + | .broadcastInDim => do + log "Compiling broadcastInDim operation" + let input := inputValues[0]! + let broadcastDims ← (← lookupAttributeValue inputAttributes "broadcast_dimensions") |> parseArray + pure [.assign output (.broadcast input outputTy.shape broadcastDims) outputTy] + | .transpose => do + log "Compiling transpose operation" + let input := inputValues[0]! + let dims ← (← lookupAttributeValue inputAttributes "permutation") |> parseArray + pure [.assign output (.transpose input dims) outputTy] + -- tensor binary operators + | .dotGeneral => do + /- + The semantics of the `dotGeneral` operation are complex, see the + specification for details. The variables here are named similar to the + variables in the spec to aid comprehension. + https://github.com/openxla/stablehlo/blob/6f7b4ab8f96dc65cf3c8e9824836117d2934cc45/docs/spec.md?#dot_general + -/ + log "Compiling dotGeneral operation" + -- Gather metadata from the inputs + let (lhsBatchingDims, lhsContractingDims, rhsBatchingDims, rhsContractingDims) ← + extractDotDimensionNumbers inputAttributes + let lhs := inputValues[0]! + let rhs := inputValues[1]! + let lhsType ← parseTensorTypeFromValueTypes signature.domain 0 + let rhsType ← parseTensorTypeFromValueTypes signature.domain 1 + let dtype := lhsType.dtype + let lhsShape := lhsType.shape + let rhsShape := rhsType.shape + let lhsDims := List.range (TensorLib.Shape.ndim lhsShape) + let rhsDims := List.range (TensorLib.Shape.ndim rhsShape) + -- Calculate shapes of intermediate tensors and output + let lhsResultDims := lhsDims.filter (fun i => !lhsBatchingDims.contains i && !lhsContractingDims.contains i) + let rhsResultDims := rhsDims.filter (fun i => !rhsBatchingDims.contains i && !rhsContractingDims.contains i) + let lhsTransposePerm := lhsBatchingDims ++ lhsResultDims ++ lhsContractingDims + let rhsTransposePerm := rhsBatchingDims ++ rhsResultDims ++ rhsContractingDims + let lhsTransposedShape := permute lhsShape.val lhsTransposePerm |>.get! + let rhsTransposedShape := permute rhsShape.val rhsTransposePerm |>.get! + let batchShape := lhsTransposedShape.take lhsBatchingDims.length + let lhsResultShape := lhsTransposedShape.drop lhsBatchingDims.length |>.take lhsResultDims.length + let rhsResultShape := rhsTransposedShape.drop rhsBatchingDims.length |>.take rhsResultDims.length + let contractingShape := lhsTransposedShape.drop (lhsBatchingDims.length + lhsResultDims.length) |> + List.take (lhsTransposedShape.length - (lhsBatchingDims.length + lhsResultDims.length)) + let batchSize := if batchShape.isEmpty then 1 else batchShape.foldl (· * ·) 1 + let resultShape := batchShape ++ lhsResultShape ++ rhsResultShape + let lhsResultSize := if lhsResultShape.isEmpty then 1 else lhsResultShape.foldl (· * ·) (1 : Nat) + let rhsResultSize := if rhsResultShape.isEmpty then 1 else rhsResultShape.foldl (· * ·) (1 : Nat) + let contractingSize := if contractingShape.isEmpty then 1 else contractingShape.foldl (· * ·) 1 + -- Create fresh variable names for intermediate results + -- TODO: this is currently not correct, since the names are not unique + let lhsTransposedName := lhs ++ "_transposed" + let rhsTransposedName := rhs ++ "_transposed" + let lhsReshapedName := lhs ++ "_reshaped" + let lhsReshapedShape := [batchSize, lhsResultSize, contractingSize] + let lhsReshapedTy := TensorTy.mk (.mk lhsReshapedShape) dtype + let rhsReshapedName := rhs ++ "_reshaped" + let rhsReshapedShape := [batchSize, rhsResultSize, contractingSize] + let rhsReshapedTy := TensorTy.mk (.mk rhsReshapedShape) dtype + let resultReshapedName := output ++ "_reshaped" + let resultReshapedShape := [batchSize, lhsResultSize, rhsResultSize] + let resultReshapedType := TensorTy.mk (.mk resultReshapedShape) dtype + -- Emit the HLR statements for the dotGeneral operation + pure ([ + .comment "Dot General Operation", + .assign lhsTransposedName (.transpose lhs lhsTransposePerm) (.mk (.mk lhsTransposedShape) dtype), + .assign rhsTransposedName (.transpose rhs rhsTransposePerm) (.mk (.mk rhsTransposedShape) dtype), + .assign lhsReshapedName (.reshape lhsTransposedName lhsReshapedTy.shape) lhsReshapedTy, + .assign rhsReshapedName (.reshape rhsTransposedName rhsReshapedTy.shape) rhsReshapedTy, + .assign resultReshapedName (.batchMatmul lhsReshapedName rhsReshapedName) resultReshapedType, + .assign output (.reshape resultReshapedName (.mk resultShape)) outputTy, + ]) + | .concatenate => + log "Compiling concatenate operation" + let tensors := inputValues + let dim ← (← lookupAttributeValue inputAttributes "dimension") |> parseNatFromFloat + pure [.assign output (.concat tensors dim) outputTy] + -- tensor ternary operators + | .select => + log "Compiling select operation" + let cond := inputValues[0]! + let a := inputValues[1]! + let b := inputValues[2]! + pure [.assign output (.select cond a b) outputTy] + | _ => throw s!"Unsupported HLO operation: {repr opCode}" + | .return ops _ => do + log "Compiling return operation" + pure [Statement.ret ops] + | .call callee inputValues outputs signature => do + log "Compiling call operation" + let output ← match outputs with + | [output] => pure output + | _ => throw "Call operator signature must have a single output." + pure [ + Statement.assign + output + (.call callee inputValues) + (← parseSingleTensorTypeFromValueTypes signature.range)] + + | s => throw s!"Unsupported operation type {repr s}" + +def compileFunc (f : StableHLO.Parsing.Function) : Compile Unit := do + let .mk args body := f.funcBody + let inputs ← args.mapM (fun ⟨ name, v ⟩ => do + match v with + | .tensorType t => parseTensorType t + | _ => throw s!"Function input {name} must have tensor type.") + let outputs ← f.funcType.range.mapM fun + | .tensorType t => parseTensorType t + | _ => throw "Function output must be a tensor type." + -- Since arguments are referred to by index, emit a statement for each + -- argument that assigns it to a named variable + let preamble ← args.mapIdxM (fun i ⟨ name, v ⟩ => do + match v with + | .tensorType t => + pure (Statement.assign name (.arg i) (← parseTensorType t)) + | _ => throw s!"Function input {name} must have tensor type.") + let statements ← body.flatMapM compileOp + let func := Function.mk f.funcId inputs outputs (preamble ++ statements) + addFunction func + +def compileModule (m : StableHLO.Parsing.Module) : Compile Unit := + m.modFuncs.forM compileFunc + +def compile (m : List StableHLO.Parsing.Module) : (Except String Unit) × Ctx := + let compiled := match m with + | [m] => compileModule m + | _ => throw "Only one module is supported for now." + match compiled.run default with + | .ok _ s => (.ok (), s) + | .error err s => (throw err, s) + +end KLR.HLR.Compile diff --git a/KLR/HLR/Dot.lean b/KLR/HLR/Dot.lean new file mode 100644 index 00000000..f59d890a --- /dev/null +++ b/KLR/HLR/Dot.lean @@ -0,0 +1,122 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.HLR.AST +import SHerLOC.Analysis.Graph + +open StableHLO.Analysis (Graph Edge Vertex) + +-- This module provides a way to convert an HLR function into a DOT graph representation. +namespace KLR.HLR.Graph + +/- +Process the name `var` so that it can used as a node ID in DOT format. +Notably, IDs can't start with a digit, so we prefix it with "node_". +-/ +def sanitize (var : String) : String := + s!"node_{var}" + +def makeReturnNode (funcName : String) : Vertex := + .mk + s!"return_{funcName}" + (.mk [ + ("label", s!"return\\n{funcName}"), + ("shape", "box"), + ("style", "filled"), + ("fillcolor", "lightgray"), + ("color", "gray") + ]) +def makeOpNode (op : Operator) (output : String) : Vertex := + let attrs := match op with + | .arg .. => [ + ("shape", "diamond"), + ("style", "filled"), + ("fillcolor", "lightgreen"), + ("color", "green") + ] + | .batchMatmul .. => [ + ("style", "filled"), + ("fillcolor", "lightpink"), + ("color", "red") + ] + | .slice .. => [ + ("style", "filled"), + ("fillcolor", "lightblue"), + ("color", "blue") + ] + | _ => [] + .mk + (sanitize output) + (.mk ([ + ("label", s!"{opName op}\\n{output}"), + ] ++ attrs)) + +def makeConstNode (name : String) (usedAt : String): Vertex := + .mk + s!"node_const_{name}_{usedAt}" + (.mk [ + ("label", s!"const\\n{name} ({usedAt})"), + ("shape", "diamond"), + ("style", "filled"), + ("fillcolor", "lightyellow"), + ("color", "yellow") + ]) + +def makeEdge (source : String) (dest : String) : Edge := + .mk + source + dest + (.mk []) + +/- +Convert an HLR function to a DOT graph, where each variable is a vertex +and an edge exists from A to B if A is used in the computation of B. + +Note: since constants are reused in many parts of the function, they can +cause the graph to have long edges that cross over other nodes. To avoid this, +we create a separate vertex for each use of a constant. +-/ +def graph (f : HLR.Function) : Graph := Id.run do + let mut vertices := [] + let mut edges := [] + -- Every variables in the function that is the result of a `constant` operatior + let mut consts := f.statements.filterMap (fun + | .assign v (.const _ _ _) _ => .some v + | _ => .none) + -- A closure that creates edges from a list of inputs to an output variable. + -- If the input is a constant, it creates a vertex for that constant. + let (makeEdges : List String → String → (List Vertex) × (List Edge)) := fun inputs output => Id.run do + let mut vertices := [] + let mut edges := [] + for input in inputs do + if consts.contains input then + let node := makeConstNode input output + vertices := node :: vertices + edges := (makeEdge node.id output) :: edges + else + edges := (makeEdge input output) :: edges + return (vertices, edges) + + -- Walk the program statements and create vertices and edges. + for s in f.statements do + match s with + | .assign _ (.const _ _ _) _ => () + | .assign v op _ => + let deps := dependencies op |>.map sanitize + let (newVertices, newEdges) := makeEdges deps (sanitize v) + vertices := [makeOpNode op v] ++ newVertices ++ vertices + edges := newEdges ++ edges + | .ret vars => + let node := makeReturnNode f.name + let deps := vars.map sanitize + let (newVertices, newEdges) := makeEdges deps node.id + vertices := [node] ++ newVertices ++ vertices + edges := newEdges ++ edges + | .comment _ => () + + .mk f.name vertices edges + +end KLR.HLR.Graph diff --git a/KLR/HLR/Py.lean b/KLR/HLR/Py.lean new file mode 100644 index 00000000..d48227a2 --- /dev/null +++ b/KLR/HLR/Py.lean @@ -0,0 +1,155 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.Core.Operators +import KLR.HLR.AST +import KLR.Util +import SHerLOC +import TensorLib.Shape +import TensorLib.Slice + +open Std.Format +open TensorLib (Dtype Shape Slice) + +/- +This module converts an HLR program into a runnable Python program. +At present, it can't convert HLR constants to python constants and can't +take input tensors, so it is only helpful to ensure that the shape +annotations are correct and that the program is well-formed. +-/ +namespace KLR.HLR.Py + +structure FormatCtx where + indent : Nat := 0 + program : String := "" + +abbrev Format := Std.Format + +def dTypeToPy : Dtype → String + | .bool => "np.bool_" + | .int8 => "np.int8" + | .int16 => "np.int16" + | .int32 => "np.int32" + | .int64 => "np.int64" + | .uint8 => "np.uint8" + | .uint16 => "np.uint16" + | .uint32 => "np.uint32" + | .uint64 => "np.uint64" + | .float32 => "np.float32" + | .float64 => "np.float64" + +def binOpToPy : BinaryOp → String + | .add => "np.add" + | .sub => "np.subtract" + | .mul => "np.multiply" + | .div => "np.divide" + | .and => "np.logical_and" + | .max => "np.maximum" + | .cmp => "np.compare" + +def unaryOpToPy : UnaryOp → String + | .exp => "np.exp" + | .sqrt => "np.sqrt" + | .neg => "np.negative" + | .convert d => s!"(lambda x: x.as_type({dTypeToPy d}))" + +def reduceOpToPy : BinaryOp → String + | .add => "np.sum" + | .max => "np.max" + | op => panic! s!"Unsupported reduction operation: {op}" + +def intLitToPy : StableHLO.Parsing.IntegerLiteral → String + | .mk .plus decimal => s!"{decimal}" + | .mk .minus decimal => s!"-{decimal}" + +def floatLitToPy : StableHLO.Parsing.FloatLiteral → String + | .decimal (.mk intPart fracPart sciPart) => + let intPartStr := intLitToPy intPart + let fracPartStr := intLitToPy fracPart + let sciPartStr := intLitToPy sciPart + s!"{intPartStr}.{fracPartStr}e{sciPartStr}" + | .hexaDecimal n => toString n + +def elementLitToPy : StableHLO.Parsing.ElementLiteral → String + | .booleanLiteral .true => "True" + | .booleanLiteral .false => "False" + | .floatLiteral f => floatLitToPy f + | .complexLiteral { real, imaginary } => + s!"complex({floatLitToPy real}, {floatLitToPy imaginary})" + | .stringLiteral str => s!"'{str}'" + +def valueToPy : StableHLO.Parsing.DenseLiteral → String + | .denseDimension n => s!"[{n.map valueToPy |> ", ".intercalate}]" + | .denseElements arr => ",".intercalate (arr.map elementLitToPy) + +def shapeToPy (s : Shape) : String := + s.val.map toString |> ",".intercalate + +def varToPy (arg : Var) : String := + -- Prefix, since Python variables can't start with a digit + s!"var_{arg}" + +def opToPy (op : Operator) : String := + match op with + | .arg index => s!"args[{index}]" + | .binaryOp binOp a b => s!"{binOpToPy binOp}({varToPy a}, {varToPy b})" + | .unaryOp unOp a => s!"{unaryOpToPy unOp}({varToPy a})" + | .reductionOp redOp a b dim => s!"{reduceOpToPy redOp}({varToPy a}, initial={varToPy b}, axis={dim[0]!})" + | .batchMatmul a b => s!"np.einsum(\"bij,bkj->bik\", {varToPy a}, {varToPy b})" + | .arange start stop step shape => s!"np.arange({start}, {stop}, {step}).reshape({shapeToPy shape})" + | .concat tensors dim => + let tensorsStr := String.intercalate "," (tensors.map toString) + s!"np.concatenate([{tensorsStr}], axis={dim})" + | .select cond a b => s!"np.where({cond}, {varToPy a}, {varToPy b})" + | .full value shape => s!"np.full(({shapeToPy shape}), {value})" + | .transpose a dims => + let dimsStr := dims.map toString |> ", ".intercalate + s!"np.transpose({varToPy a}, axes=[{dimsStr}])" + | .split_with_sizes .. => panic! s!"Split with sizes operation not implemented in Python translation" + | .reshape a shape => s!"{varToPy a}.reshape({shapeToPy shape})" + | .broadcast a shape dims => s!"jax.lax.broadcast_in_dim({varToPy a}, ({shapeToPy shape}), {dims})" + | .const _ shape _ => s!"np.random.random(({shapeToPy shape})" -- TODO: make this use the actual constant value + | .gather .. => panic! s!"Gather operation not implemented in Python translation" + | .slice .. => panic! s!"Slice operation not implemented in Python translation" + | .call .. => + panic! s!"Can't translate call operators to Python" + +def compileStatement (s : Statement) : Format := + match s with + | .comment msg => text s!"# {msg}" + | .assign dest op {shape, ..} => + text (s!"{varToPy dest} = {opToPy op} # {shape}") ++ + if shape.ndim != 0 then + line ++ text (s!"assert {varToPy dest}.shape == ({shapeToPy shape},), \"Expected %s, got %s\" % (({shapeToPy shape}), {varToPy dest}.shape)") + else + nil + | .ret vars => + let varNames := vars.map varToPy |> ", ".intercalate + text s!"return {varNames}" + +def compileFunction (f : Function) : Format := + let inputsStr := f.inputs.map (fun {shape, ..} => s!"\"np.ndarray[({shapeToPy shape})]\"") |> ", ".intercalate |> fun x => s!"args: Tuple[{x}]" + let outputsStr := f.outputs.map (fun {shape, ..} => shapeToPy shape) |> ", ".intercalate + let funcHeader := s!"def {f.name}({inputsStr}) -> (\"np.ndarray[({outputsStr})]\"):" + let args := f.inputs.map (fun {shape, ..} => s!"np.random.random(({shapeToPy shape}))") + let funCall := s!"{f.name}([{", ".intercalate args}])" + text funcHeader ++ + (nest 4 (prefixJoin line (f.statements.map compileStatement))) ++ line ++ + text funCall + +def compileProgram (p : Program) : Format := + let lines := [ + text "import numpy as np", + text "import jax", + text "from typing import Tuple"] ++ + p.functions.map compileFunction + joinSep lines line + +-- Compile the HLR program to a Python program. +def compile (p : Program) : String := + (compileProgram p).pretty + +end KLR.HLR.Py diff --git a/KLR/lake-manifest.json b/KLR/lake-manifest.json index 65a9531e..e6ad0719 100644 --- a/KLR/lake-manifest.json +++ b/KLR/lake-manifest.json @@ -1,7 +1,17 @@ {"version": "1.1.0", "packagesDir": ".lake/packages", "packages": - [{"url": "https://github.com/leanprover/TensorLib.git", + [{"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": false, + "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover/TensorLib.git", "type": "git", "subDir": null, "scope": "", @@ -41,6 +51,16 @@ "inputRev": "v4.21.0", "inherited": false, "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover/lean4-cli.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "02dbd02bc00ec4916e99b04b2245b30200e200d0", + "name": "Cli", + "manifestFile": "lake-manifest.json", + "inputRev": "v4.19.0", + "inherited": true, + "configFile": "lakefile.toml"}, {"url": "https://github.com/leanprover-community/import-graph.git", "type": "git", "subDir": null, @@ -51,12 +71,22 @@ "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover/lean4-cli.git", + {"url": "https://github.com/leanprover-community/aesop", "type": "git", "subDir": null, "scope": "", - "rev": "f9e25dcbed001489c53bceeb1f1d50bbaf7451d4", - "name": "Cli", + "rev": "ddfca7829bf8aa4083cdf9633935dddbb28b7b2a", + "name": "aesop", + "manifestFile": "lake-manifest.json", + "inputRev": "v4.20.0", + "inherited": true, + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/batteries", + "type": "git", + "subDir": null, + "scope": "leanprover-community", + "rev": "7a0d63fbf8fd350e891868a06d9927efa545ac1e", + "name": "batteries", "manifestFile": "lake-manifest.json", "inputRev": "v4.20.0", "inherited": true, diff --git a/KLR/lakefile.lean b/KLR/lakefile.lean index 2150fff4..5fd33e39 100644 --- a/KLR/lakefile.lean +++ b/KLR/lakefile.lean @@ -28,5 +28,8 @@ require plausible from git require TensorLib from git "https://github.com/leanprover/TensorLib.git" @ "v0.0.12" +require SHerLOC from git + "https://github.com/leanprover/SHerLOC.git" @ "c74ae090d4326cca9ff636184c330a67ca039ef6" + -- Comment the above and uncomment this for local development -- require TensorLib from "../../TensorLib" diff --git a/Main.lean b/Main.lean index 9898e56e..d78a7a59 100644 --- a/Main.lean +++ b/Main.lean @@ -18,6 +18,9 @@ import Cli import KLR import TensorLib.Npy import TensorLib.Tensor +import SHerLOC +import KLR.HLR +import SHerLOC.Analysis.Graph open Cli open KLR @@ -298,6 +301,32 @@ def evalKLR (p : Parsed) : IO UInt32 := do TensorLib.Npy.Ndarray.save! arr filename return 0 +def hloToHLR (p : Parsed) : IO UInt32 := do + let file := p.positionalArg! "file" |>.as! String + let s <- IO.FS.readFile file + match StableHLO.Parsing.parse s with + | .ok (hlo, _) => + let hlr := KLR.HLR.Compile.compile hlo + match hlr with + | (.ok _, s) => do + let hlr := s.program + IO.println (toString hlr) + let headFunction := s.program.functions.head! + -- print graph of function + let g := KLR.HLR.Graph.graph headFunction |> toString + writeContent "dot" p g + -- print HLR program as Python program + let py := KLR.HLR.Py.compile hlr + writeContent "py" p py + return 0 + | (.error e, s) => do + IO.eprintln s!"Error compiling HLO to HLR: {e}" + IO.eprintln s!"{repr s}" + return 1 + | .error e => + IO.eprintln e + return 1 + -- -- Command configuration def gatherCmd := `[Cli| @@ -388,6 +417,15 @@ def evalKLRCmd := `[Cli| kernelFunctionName : String; "Name of the kernel function" ...inputFiles : String; ".npy files corresponding to the inputs to the kernel, in positional order" ] +def hloToHLRCmd := `[Cli| + "hlo-to-hlr" VIA hloToHLR; + "Compile HLO graph to HLR graph" + + FLAGS: + o, outfile : String; "Name of output file" + ARGS: + file : String; "File of HLO graph in .mlir format" +] def klrCmd : Cmd := `[Cli| klr NOOP; ["0.0.12"] @@ -401,6 +439,7 @@ def klrCmd : Cmd := `[Cli| nkiToKLRCmd; traceCmd; typecheckCmd + hloToHLRCmd ] def main (args : List String) : IO UInt32 := do diff --git a/lake-manifest.json b/lake-manifest.json index ccb3050a..2c28de0d 100644 --- a/lake-manifest.json +++ b/lake-manifest.json @@ -8,6 +8,16 @@ "inherited": false, "dir": "KLR", "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": false, + "configFile": "lakefile.lean"}, {"url": "https://github.com/leanprover/TensorLib.git", "type": "git", "subDir": null, diff --git a/lakefile.lean b/lakefile.lean index 34142f2e..d5441db5 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -58,6 +58,9 @@ require plausible from git require TensorLib from git "https://github.com/leanprover/TensorLib.git" @ "v0.0.13" +require SHerLOC from git + "https://github.com/leanprover/SHerLOC.git" @ "c74ae090d4326cca9ff636184c330a67ca039ef6" + -- Comment the above and uncomment this for local development -- require TensorLib from "../TensorLib" From fe59e14c4f6b04befb5cd72958c195ebf151fe4a Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 9 Jul 2025 16:19:49 -0400 Subject: [PATCH 02/15] Fix compilation error in Main --- Main.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Main.lean b/Main.lean index d78a7a59..944cc011 100644 --- a/Main.lean +++ b/Main.lean @@ -438,7 +438,7 @@ def klrCmd : Cmd := `[Cli| infoCmd; nkiToKLRCmd; traceCmd; - typecheckCmd + typecheckCmd; hloToHLRCmd ] From d5407ca7bbd4f0666fb8156b0cd347c61bf0a1d6 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 9 Jul 2025 16:39:37 -0400 Subject: [PATCH 03/15] Fix compilation of constants and broadcast --- KLR/HLR/AST.lean | 32 +++++----- KLR/HLR/Compile.lean | 135 +++++++++++++++++++++++++++++++++---------- KLR/HLR/Py.lean | 6 +- KLR/Util/Hex.lean | 7 +++ 4 files changed, 133 insertions(+), 47 deletions(-) diff --git a/KLR/HLR/AST.lean b/KLR/HLR/AST.lean index 73a1608a..4e51fa3c 100644 --- a/KLR/HLR/AST.lean +++ b/KLR/HLR/AST.lean @@ -12,7 +12,7 @@ import TensorLib.Shape import TensorLib.Slice import TensorLib.Tensor -open TensorLib (Shape Dtype) +open TensorLib (Tensor Shape Dtype Slice) /- The definition of the High-Level Representation (HLR) IR. The goal of this IR is to @@ -46,7 +46,7 @@ inductive UnaryOp where | exp | sqrt | neg - | convert (dtype : TensorLib.Dtype) + | convert (dtype : Dtype) deriving Inhabited, Repr /- @@ -80,25 +80,31 @@ inductive Operator where -- select elements from two tensors based on a condition tensor | select (cond a b : Var) -- create a tensor filled with a specific value, with the given shape - | full (value : Float32) (shape : Shape) + -- Note: the tensor is always a scalar-array + | full (value : Tensor) (shape : Shape) -- transpose a tensor with the provided permutation of dimensions | transpose (a : Var) (dims : List Nat) -- unused | split_with_sizes (a : Var) (sizes : List Nat) -- ?? -- reshape a tensor to the specified shape | reshape (a : Var) (shape : Shape) - -- broadcast a tensor to the specified shape - -- TODO: broadcasting is very complicated and we haven't figured it out yet, - -- so this instruction just passes through the semantics of HLO's broadcasting - | broadcast (a : Var) (shape : Shape) (broadcastDims : List Nat) + /- + broadcast a tensor to the specified shape + + It must be the case that `len(shape(a)) = len(shape)` and that + `∀ i, shape(a)[i] != shape[i] => shape(a)[i] == 1` + In other words, the broadcast cannot produce new dimensions, only expand + existing ones of size 1. + -/ + | broadcast (a : Var) (shape : Shape) -- create a constant tensor with the given values and shape - | const (values : TensorLib.Tensor) (shape : Shape) (dtype : TensorLib.Dtype) + | const (values : Tensor) (shape : Shape) (dtype : Dtype) -- gather elements from a tensor using the provided indices and offset dimensions -- TODO: gather is complicated and not used except for in llama, so for now -- we just pass through the semantics of HLO's gather | gather (input indices : Var) (offsetDims collapsedSliceDims startIndexMap : List Nat) (indexVectorDim : Nat) -- slice a tensor along specified dimensions, with start, limit, and stride - | slice (a : Var) (slice : List TensorLib.Slice) + | slice (a : Var) (slice : List Slice) -- call another function, passing input values and receiving outputs | call (callee : String) (inputValues : List Var) deriving Inhabited, Repr @@ -170,7 +176,7 @@ def findVar (f : Function) (v : Var) : Option Operator := | _ => .none) -- TODO: move these toString instances to the TensorLib repo -instance : ToString TensorLib.Slice where +instance : ToString Slice where toString s := let {start, stop, step, ..} := s let start := start.map toString |>.getD "" @@ -178,11 +184,11 @@ instance : ToString TensorLib.Slice where let step := step.map toString |>.getD "" s!"{start}:{stop}:{step}" -instance : ToString TensorLib.Shape where +instance : ToString Shape where toString s := s.val.map toString |> "x".intercalate |> fun x => s!"[{x}]" -instance : ToString TensorLib.Dtype where +instance : ToString Dtype where toString | .bool => "bool" | .int8 => "i8" @@ -231,7 +237,7 @@ instance : ToString Operator where | .transpose a dims => s!"transpose({a}, dims={dims})" | .split_with_sizes a sizes => s!"split_with_sizes({a}, sizes={sizes})" | .reshape a shape => s!"reshape({a}, shape={shape})" - | .broadcast a shape dims => s!"broadcast({a}, shape={shape}, dims={dims})" + | .broadcast a shape => s!"broadcast({a}, shape={shape})" | .const t shape dtype => s!"const({repr t}, shape={shape}, dtype={dtype})" | .gather a indices offsetDims collapsedSliceDims startIndexMap indexVectorDim => s!" gather({a}, indices={indices}, offsetDims={offsetDims}, collapsedSliceDims={collapsedSliceDims}, startIndexMap={startIndexMap}, indexVectorDim={indexVectorDim})" diff --git a/KLR/HLR/Compile.lean b/KLR/HLR/Compile.lean index 5ee6965d..f1ac4a0f 100644 --- a/KLR/HLR/Compile.lean +++ b/KLR/HLR/Compile.lean @@ -11,6 +11,8 @@ import SHerLOC import TensorLib.Shape import TensorLib.Slice import TensorLib.Tensor +import TensorLib.Bytes +import TensorLib.ByteArray open TensorLib (Dtype Shape Tensor) @@ -84,37 +86,49 @@ def parseFloat (c : StableHLO.Parsing.FloatLiteral) : Float32 := else (true, exponent.natAbs) sign * OfScientific.ofScientific mantissa exponentSign exponent - | .hexaDecimal _ => panic! "Hexadecimal float literals are not supported yet." + | .hexaDecimal n => UInt32.ofNat n |> Float32.ofBits #guard parseFloat (.decimal { integerPart := ⟨ .plus, 4 ⟩, fractionalPart := ⟨ .plus, 785980 ⟩, scientificPart := ⟨ .plus, 3 ⟩}) == 4785.980 #guard parseFloat (.decimal { integerPart := ⟨ .minus, 4 ⟩, fractionalPart := ⟨ .plus, 785980 ⟩, scientificPart := ⟨ .minus, 1 ⟩}) == -0.4785980 #guard parseFloat (.decimal { integerPart := ⟨ .minus, 3 ⟩, fractionalPart := ⟨ .plus, 597620 ⟩, scientificPart := ⟨ .minus, 3 ⟩}) == -0.003597620 --- Parse a StableHLO element literal to a Float32. -def parseFloatFromElementLiteral (c : StableHLO.Parsing.ElementLiteral) : Compile Float32 := - match c with - |(.floatLiteral f) => pure (parseFloat f) - | _ => throw s!"Expected a float literal, but got {repr c}." - --- Convert a list of Float32 values to a TensorLib tensor. -def ofFloatList (ns : List Float32) : Compile Tensor := do - let dtype := TensorLib.Dtype.float32 - let size := dtype.itemsize - let arr := Tensor.zeros dtype (Shape.mk [ns.length]) - let mut data := arr.data - let mut posn := 0 - for n in ns do - let v <- dtype.byteArrayOfFloat32 n - data := v.copySlice 0 data posn size - posn := posn + size - .ok { arr with data := data } +/- +Parses a hex string (starting with "0x" and preceded by a multipleof 8 hex characters) +into a tensor of int32 values. + +Assumes bytes are in big-endian order. +-/ +def parseInt32TensorFromHex (s : String) : Compile Tensor := do + -- Tail recursive helper to convert a list of hex characters to a list of BitVec 32 values. + let rec toBitVec32List (str : List Char) (acc : List (BitVec 32)) : Compile (List (BitVec 32)):= do + match str with + | c0 :: c1 :: c2 :: c3 :: c4 :: c5 :: c6 :: c7 :: rest => + match KLR.Util.Hex.hexCharsToBitVecBE c0 c1 c2 c3 c4 c5 c6 c7 with + | some v => toBitVec32List rest (v :: acc) + | none => throw s!"Invalid hex character sequence: {c0}{c1}{c2}{c3}." + | [] => pure acc.reverse + | _ => throw s!"Hex string must have a number of characters divisible by 8, but got {str}." + + -- Trim off the leading "0x" and convert the rest to a list of BitVec 32 values. + let bvs ← match s.toList with + | '0' :: 'x' :: rest => toBitVec32List rest [] + | _ => throw s!"Hex string must start with '0x', but got {s}." + -- Concatenate the bitvecs to create a little-endian bytearray + let data := bvs.foldr (fun new acc => (TensorLib.toLEByteArray new) ++ acc) ByteArray.empty + pure { + dtype := TensorLib.Dtype.int32, + shape := Shape.mk [bvs.length], + data + } -- Convert a list of tensors to a single tensor by concatenating them along a new first dimension. -def ofTensorList (ns : List Tensor) : Compile Tensor := +def tensorOfTensorList (ns : List Tensor) : Compile Tensor := match ns with | f :: r => do if ! (r.all fun t => t.shape == f.shape) then throw s!"All tensors in the list must have the same shape, but got {repr ns}." + else if ! (r.all fun t => t.dtype == f.dtype) then + throw s!"All tensors in the list must have the same dtype, but got {repr ns}." else let newShape := Shape.mk (ns.length :: f.shape.val) let dtype := f.dtype @@ -128,12 +142,33 @@ def ofTensorList (ns : List Tensor) : Compile Tensor := .ok { arr with data := data } | [] => pure (TensorLib.Tensor.empty TensorLib.Dtype.float32 (Shape.mk [])) +/- +Parse a StableHLO string literal to a Tensor + +Note that the meaning of string literals in StableHLO is not well-defined, +but in practice JAX uses them to hex encode integer tensors. +-/ +def parseStringLiteral := parseInt32TensorFromHex +-- Parse a StableHLO boolean literal to a Bool. +def parseBoolLiteral : StableHLO.Parsing.BooleanLiteral → Bool + | .true => true + | .false => false +-- Parse a StableHLO element literal to an HLR tensor. +def parseElementLiteral : StableHLO.Parsing.ElementLiteral → Compile TensorLib.Tensor + | .floatLiteral f => f |> parseFloat |> TensorLib.Tensor.arrayScalarFloat32! |> pure + | .stringLiteral s => s |> parseStringLiteral + | .booleanLiteral b => b |> parseBoolLiteral |> TensorLib.Tensor.arrayScalarBool! |> pure + | .complexLiteral _ => impossible "unimplemented" + -- Parse a StableHLO dense literal to an HLR tensor. def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor - | .denseDimension values => do - ofTensorList (← values.mapM parseTensorLiteral) - | .denseElements elems => do - (← elems.mapM parseFloatFromElementLiteral) |> ofFloatList + -- special case for singleton tensors so we don't create an extra dimension + | .denseElements [v] => do + parseElementLiteral v + | .denseElements l => do + tensorOfTensorList (← l.mapM parseElementLiteral) + | .denseDimension l => do + tensorOfTensorList (← l.mapM parseTensorLiteral) -- Convert a StableHLO tensor type to an HLR TensorTy. def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do @@ -142,6 +177,16 @@ def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do | .unknown => throw "Can't support tensors with dynamic shape") let dtype ← match t.tensorElementTypeGen with | .classic (.floatType .f32) => pure TensorLib.Dtype.float32 + | .classic (.floatType .f64) => pure TensorLib.Dtype.float64 + | .classic (.integerType {sign := .signed, size := .b8}) => pure TensorLib.Dtype.int8 + | .classic (.integerType {sign := .unsigned, size := .b8}) => pure TensorLib.Dtype.uint8 + | .classic (.integerType {sign := .signed, size := .b16}) => pure TensorLib.Dtype.int16 + | .classic (.integerType {sign := .unsigned, size := .b16}) => pure TensorLib.Dtype.uint16 + | .classic (.integerType {sign := .signed, size := .b32}) => pure TensorLib.Dtype.int32 + | .classic (.integerType {sign := .unsigned, size := .b32}) => pure TensorLib.Dtype.uint32 + | .classic (.integerType {sign := .signed, size := .b64}) => pure TensorLib.Dtype.int64 + | .classic (.integerType {sign := .unsigned, size := .b64}) => pure TensorLib.Dtype.uint64 + | .classic (.booleanType) => pure TensorLib.Dtype.bool | _ => throw s!"Unsupported tensor element type: {repr t.tensorElementTypeGen}" pure (.mk (.mk shape) dtype) @@ -168,7 +213,7 @@ def parseArray (c : StableHLO.Parsing.Literal) : Compile (List Nat) := Parse a Nat from a StableHLO float literal. We need this because integers are often represented as floats in StableHLO. -/ -def parseNatFromFloat (c : StableHLO.Parsing.Literal) : Compile Nat := +def parseNatFromElementLiteral (c : StableHLO.Parsing.Literal) : Compile Nat := match c with | .element (.floatLiteral (.decimal {integerPart, fractionalPart, scientificPart})) => match (fractionalPart.decimal == 0, scientificPart.decimal == 0, integerPart.sign) with @@ -295,10 +340,19 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) log "Compiling constant operation" let valueAttr ← lookupAttributeValue inputAttributes "value" match valueAttr with - | (.tensor (.denseElements [(.floatLiteral f)])) => - pure [.assign output (.full (parseFloat f) outputTy.shape) outputTy] - | (.tensor lit) => - pure [.assign output (.const (← parseTensorLiteral lit) outputTy.shape outputTy.dtype) outputTy] + | (.tensor t) => do + let t ← parseTensorLiteral t + if t.shape == Shape.empty then + -- If the tensor is a scalar-array, we use a `full` operation + -- to create a tensor of the same shape as the output. + pure [.assign output (.full t outputTy.shape) outputTy] + else + -- If the tensor is not a scalar-array, it corresponds to a + -- `const` operation. + if t.shape.count != outputTy.shape.count then + throw s!"Tensor literal shape {t.shape} does not match expected output shape {outputTy.shape}." + let t ← t.reshape outputTy.shape + pure [.assign output (.const t outputTy.shape outputTy.dtype) outputTy] | _ => throw "Constant operation requires a 'value' attribute with tensor literal." -- tensor unary operators | .reshape => do @@ -333,10 +387,29 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) let dims ← (← lookupAttributeValue inputAttributes "dimensions") |> parseArray pure [.assign output (.reductionOp op inputValues[0]! inputValues[1]! dims) outputTy] -- TODO: init value | .broadcastInDim => do + /- + We compile a broadcast by first reshaping to a new tensor `t` with added + dimensions of size 1 such that the rank is equal to the output rank, then + we broadcast `t` to the output shape which expands some of those + dimensions of size 1 + -/ log "Compiling broadcastInDim operation" let input := inputValues[0]! + let inputTy := ← parseTensorTypeFromValueTypes signature.domain 0 let broadcastDims ← (← lookupAttributeValue inputAttributes "broadcast_dimensions") |> parseArray - pure [.assign output (.broadcast input outputTy.shape broadcastDims) outputTy] + let reshaped := input ++ "_reshaped" -- TODO: need fresh var name here + -- A shape that has the same number of dimensions as the output, but where + -- specified dimensions match the input shape, and others are 1. + let newShape := outputTy.shape.ndim |> List.range |> List.map (fun n => + if let .some i := broadcastDims.idxOf? n then + inputTy.shape.val[i]! + else + 1) + |> .mk + pure [ + .assign reshaped (.reshape input newShape) ⟨ newShape, inputTy.dtype ⟩, + .assign output (.broadcast reshaped outputTy.shape) outputTy + ] | .transpose => do log "Compiling transpose operation" let input := inputValues[0]! @@ -406,7 +479,7 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) | .concatenate => log "Compiling concatenate operation" let tensors := inputValues - let dim ← (← lookupAttributeValue inputAttributes "dimension") |> parseNatFromFloat + let dim ← (← lookupAttributeValue inputAttributes "dimension") |> parseNatFromElementLiteral pure [.assign output (.concat tensors dim) outputTy] -- tensor ternary operators | .select => diff --git a/KLR/HLR/Py.lean b/KLR/HLR/Py.lean index d48227a2..6fc8b239 100644 --- a/KLR/HLR/Py.lean +++ b/KLR/HLR/Py.lean @@ -104,14 +104,14 @@ def opToPy (op : Operator) : String := let tensorsStr := String.intercalate "," (tensors.map toString) s!"np.concatenate([{tensorsStr}], axis={dim})" | .select cond a b => s!"np.where({cond}, {varToPy a}, {varToPy b})" - | .full value shape => s!"np.full(({shapeToPy shape}), {value})" + | .full _value shape => s!"np.full(({shapeToPy shape}), 0)" -- TODO: make this use the actual value | .transpose a dims => let dimsStr := dims.map toString |> ", ".intercalate s!"np.transpose({varToPy a}, axes=[{dimsStr}])" | .split_with_sizes .. => panic! s!"Split with sizes operation not implemented in Python translation" | .reshape a shape => s!"{varToPy a}.reshape({shapeToPy shape})" - | .broadcast a shape dims => s!"jax.lax.broadcast_in_dim({varToPy a}, ({shapeToPy shape}), {dims})" - | .const _ shape _ => s!"np.random.random(({shapeToPy shape})" -- TODO: make this use the actual constant value + | .broadcast a shape => s!"np.broadcast_to({varToPy a}, ({shapeToPy shape}))" + | .const _ shape _ => s!"np.random.random(({shapeToPy shape}))" -- TODO: make this use the actual constant value | .gather .. => panic! s!"Gather operation not implemented in Python translation" | .slice .. => panic! s!"Slice operation not implemented in Python translation" | .call .. => diff --git a/KLR/Util/Hex.lean b/KLR/Util/Hex.lean index ea77d554..4136e248 100644 --- a/KLR/Util/Hex.lean +++ b/KLR/Util/Hex.lean @@ -40,6 +40,13 @@ private def hexCharToUInt8 (high : Char) (low : Char) : Option UInt8 := do let lowNibble ← hexCharToNibble low return (highNibble <<< 4) + lowNibble +def hexCharsToBitVecBE (c0 c1 c2 c3 c4 c5 c6 c7: Char) : Option (BitVec 32) := do + let b0 := (← hexCharToUInt8 c0 c1).toBitVec + let b1 := (← hexCharToUInt8 c2 c3).toBitVec + let b2 := (← hexCharToUInt8 c4 c5).toBitVec + let b3 := (← hexCharToUInt8 c6 c7).toBitVec + pure (b0 ++ b1 ++ b2 ++ b3) + def decode (s : String) : Option ByteArray := Id.run do let rec split : List Char -> List (Char × Char) | [] | [_] => [] From cfbfd2136d0de9f96a9d572e28aa9d1b71aae143 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 9 Jul 2025 16:39:45 -0400 Subject: [PATCH 04/15] Add shapes to graph viz --- KLR/HLR/Dot.lean | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/KLR/HLR/Dot.lean b/KLR/HLR/Dot.lean index f59d890a..8c2af992 100644 --- a/KLR/HLR/Dot.lean +++ b/KLR/HLR/Dot.lean @@ -29,7 +29,7 @@ def makeReturnNode (funcName : String) : Vertex := ("fillcolor", "lightgray"), ("color", "gray") ]) -def makeOpNode (op : Operator) (output : String) : Vertex := +def makeOpNode (op : Operator) (output : String) (ty : KLR.HLR.TensorTy): Vertex := let attrs := match op with | .arg .. => [ ("shape", "diamond"), @@ -51,7 +51,7 @@ def makeOpNode (op : Operator) (output : String) : Vertex := .mk (sanitize output) (.mk ([ - ("label", s!"{opName op}\\n{output}"), + ("label", s!"{opName op}\\n{output}\n{ty.shape}"), ] ++ attrs)) def makeConstNode (name : String) (usedAt : String): Vertex := @@ -104,10 +104,10 @@ def graph (f : HLR.Function) : Graph := Id.run do for s in f.statements do match s with | .assign _ (.const _ _ _) _ => () - | .assign v op _ => + | .assign v op ty => let deps := dependencies op |>.map sanitize let (newVertices, newEdges) := makeEdges deps (sanitize v) - vertices := [makeOpNode op v] ++ newVertices ++ vertices + vertices := [makeOpNode op v ty] ++ newVertices ++ vertices edges := newEdges ++ edges | .ret vars => let node := makeReturnNode f.name From 7b994d582b8ce7fc7f206e915fca1550f9845e4c Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 9 Jul 2025 16:47:59 -0400 Subject: [PATCH 05/15] Add fresh symbol generation to HLO compilation --- KLR/HLR/Compile.lean | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/KLR/HLR/Compile.lean b/KLR/HLR/Compile.lean index f1ac4a0f..7c10108d 100644 --- a/KLR/HLR/Compile.lean +++ b/KLR/HLR/Compile.lean @@ -19,12 +19,15 @@ open TensorLib (Dtype Shape Tensor) -- This module compiles a StableHLO program into an HLR program. namespace KLR.HLR.Compile +abbrev SymbolEnv := List (String × Nat) + -- Context for the compilation process, to be stored in a state monad. structure Ctx where -- the program being compiled program : Program -- the log of messages generated during compilation (for debugging) log : List String + gensymEnv : SymbolEnv deriving Inhabited, Repr -- Compilation requires tracking state and also potentially returning errors. @@ -39,6 +42,21 @@ def addFunction (func : Function) : Compile Unit := do modify (fun ctx => { ctx with program := { ctx.program with functions := ctx.program.functions ++ [func] } }) +/- +Generate a fresh variable name based on a given name. + +TODO: This does not actually guarantee that the name is unique, since it +only checks the gensymEnv. We also need to check against all existing +variables in the program. +-/ +def gensym (name : String) : Compile Var := do + let ctx ← get + let idx := match ctx.gensymEnv.find? (fun ⟨ n, _ ⟩ => n == name) with + | some (_, i) => i + 1 + | none => 0 + modify (fun ctx => { ctx with gensymEnv := (name, idx) :: ctx.gensymEnv }) + pure s!"{name}_{idx}" + -- Permute `l` according to the indices in `permutation`. def permute {T : Type} (l : List T) (permutation : List Nat) : Option (List T) := permutation.mapM fun dim => l[dim]? @@ -397,7 +415,7 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) let input := inputValues[0]! let inputTy := ← parseTensorTypeFromValueTypes signature.domain 0 let broadcastDims ← (← lookupAttributeValue inputAttributes "broadcast_dimensions") |> parseArray - let reshaped := input ++ "_reshaped" -- TODO: need fresh var name here + let reshaped ← gensym (input ++ "_reshaped") -- A shape that has the same number of dimensions as the output, but where -- specified dimensions match the input shape, and others are 1. let newShape := outputTy.shape.ndim |> List.range |> List.map (fun n => @@ -454,16 +472,15 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) let rhsResultSize := if rhsResultShape.isEmpty then 1 else rhsResultShape.foldl (· * ·) (1 : Nat) let contractingSize := if contractingShape.isEmpty then 1 else contractingShape.foldl (· * ·) 1 -- Create fresh variable names for intermediate results - -- TODO: this is currently not correct, since the names are not unique - let lhsTransposedName := lhs ++ "_transposed" - let rhsTransposedName := rhs ++ "_transposed" - let lhsReshapedName := lhs ++ "_reshaped" + let lhsTransposedName ← gensym (lhs ++ "_transposed") + let rhsTransposedName ← gensym (rhs ++ "_transposed") + let lhsReshapedName ← gensym (lhs ++ "_reshaped") let lhsReshapedShape := [batchSize, lhsResultSize, contractingSize] let lhsReshapedTy := TensorTy.mk (.mk lhsReshapedShape) dtype - let rhsReshapedName := rhs ++ "_reshaped" + let rhsReshapedName ← gensym (rhs ++ "_reshaped") let rhsReshapedShape := [batchSize, rhsResultSize, contractingSize] let rhsReshapedTy := TensorTy.mk (.mk rhsReshapedShape) dtype - let resultReshapedName := output ++ "_reshaped" + let resultReshapedName ← gensym (output ++ "_reshaped") let resultReshapedShape := [batchSize, lhsResultSize, rhsResultSize] let resultReshapedType := TensorTy.mk (.mk resultReshapedShape) dtype -- Emit the HLR statements for the dotGeneral operation From cc0fe0daa2089364d33943f725ef7307c4752693 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Tue, 15 Jul 2025 15:58:03 -0400 Subject: [PATCH 06/15] Make logging less fine-grained --- KLR/HLR/Compile.lean | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/KLR/HLR/Compile.lean b/KLR/HLR/Compile.lean index 7c10108d..e6cf3b8f 100644 --- a/KLR/HLR/Compile.lean +++ b/KLR/HLR/Compile.lean @@ -321,7 +321,9 @@ This means it's possible that if there's a mistake in the shape calculation in the HLO program, the HLR statements will also have incorrect shapes. Eventually, we'll want a function that can shape-check an HLR program. -/ -def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) +def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := do + log s!"Compiling operation {repr op}" + match op with | .stablehlo opCode inputValues inputFunctions inputAttributes outputs signature => do -- Reuse the variable names and shapes from the StableHLO program let output ← match outputs with @@ -330,12 +332,10 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) let outputTy ← parseSingleTensorTypeFromValueTypes signature.range -- helper function to emit HLR for element-wise unary ops let makeUnOp := fun (op : UnaryOp) => do - log s!"Compiling unary op {op}" let a := inputValues[0]! pure [.assign output (.unaryOp op a) outputTy] -- helper function to emit HLR for element-wise binary ops let makeBinOp := fun (op : BinaryOp) => do - log s!"Compiling binary op {op}" let a := inputValues[0]! let b := inputValues[1]! pure [.assign output (.binaryOp op a b) outputTy] @@ -355,7 +355,6 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) | .divide => makeBinOp .div -- tensor nullary operators | .constant => do - log "Compiling constant operation" let valueAttr ← lookupAttributeValue inputAttributes "value" match valueAttr with | (.tensor t) => do @@ -374,18 +373,15 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) | _ => throw "Constant operation requires a 'value' attribute with tensor literal." -- tensor unary operators | .reshape => do - log "reshape" let input := inputValues[0]! pure [.assign output (.reshape input outputTy.shape) outputTy] | .gather => - log "Compiling gather operation" let (offsetDims, collapsedSliceDims, startIndexMap, indexVectorDim) ← extractDimensionNumbers inputAttributes let input := inputValues[0]! let indices := inputValues[1]! pure [.assign output (.gather input indices offsetDims collapsedSliceDims startIndexMap indexVectorDim) outputTy] | .slice => - log "Compiling slice operation" let input := inputValues[0]! let start ← (← lookupAttributeValue inputAttributes "start_indices") |> parseArray let limit ← (← lookupAttributeValue inputAttributes "limit_indices") |> parseArray @@ -400,7 +396,6 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) (.some $ Int.ofNat stride[i]!)) pure [.assign output (.slice input slices) outputTy] | .reduce => - log "Compiling reduce operation" let op ← reduceFunctionToReduceOp inputFunctions[0]! let dims ← (← lookupAttributeValue inputAttributes "dimensions") |> parseArray pure [.assign output (.reductionOp op inputValues[0]! inputValues[1]! dims) outputTy] -- TODO: init value @@ -411,7 +406,6 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) we broadcast `t` to the output shape which expands some of those dimensions of size 1 -/ - log "Compiling broadcastInDim operation" let input := inputValues[0]! let inputTy := ← parseTensorTypeFromValueTypes signature.domain 0 let broadcastDims ← (← lookupAttributeValue inputAttributes "broadcast_dimensions") |> parseArray @@ -429,7 +423,6 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) .assign output (.broadcast reshaped outputTy.shape) outputTy ] | .transpose => do - log "Compiling transpose operation" let input := inputValues[0]! let dims ← (← lookupAttributeValue inputAttributes "permutation") |> parseArray pure [.assign output (.transpose input dims) outputTy] @@ -441,7 +434,6 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) variables in the spec to aid comprehension. https://github.com/openxla/stablehlo/blob/6f7b4ab8f96dc65cf3c8e9824836117d2934cc45/docs/spec.md?#dot_general -/ - log "Compiling dotGeneral operation" -- Gather metadata from the inputs let (lhsBatchingDims, lhsContractingDims, rhsBatchingDims, rhsContractingDims) ← extractDotDimensionNumbers inputAttributes @@ -494,23 +486,19 @@ def compileOp : StableHLO.Parsing.Operation → Compile (List Statement) .assign output (.reshape resultReshapedName (.mk resultShape)) outputTy, ]) | .concatenate => - log "Compiling concatenate operation" let tensors := inputValues let dim ← (← lookupAttributeValue inputAttributes "dimension") |> parseNatFromElementLiteral pure [.assign output (.concat tensors dim) outputTy] -- tensor ternary operators | .select => - log "Compiling select operation" let cond := inputValues[0]! let a := inputValues[1]! let b := inputValues[2]! pure [.assign output (.select cond a b) outputTy] | _ => throw s!"Unsupported HLO operation: {repr opCode}" | .return ops _ => do - log "Compiling return operation" pure [Statement.ret ops] | .call callee inputValues outputs signature => do - log "Compiling call operation" let output ← match outputs with | [output] => pure output | _ => throw "Call operator signature must have a single output." From bd5051840603a59759a372fc14b9c2788174f826 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Tue, 15 Jul 2025 16:02:35 -0400 Subject: [PATCH 07/15] Rename HLR->TGR --- KLR/HLR.lean | 2 +- KLR/HLR/AST.lean | 16 ++++++++-------- KLR/HLR/Basic.lean | 8 ++++---- KLR/HLR/Compile.lean | 36 ++++++++++++++++++------------------ KLR/HLR/Dot.lean | 14 +++++++------- KLR/HLR/Py.lean | 12 ++++++------ Main.lean | 22 +++++++++++----------- 7 files changed, 55 insertions(+), 55 deletions(-) diff --git a/KLR/HLR.lean b/KLR/HLR.lean index 4ce2b3ec..0a76b340 100644 --- a/KLR/HLR.lean +++ b/KLR/HLR.lean @@ -1 +1 @@ -import KLR.HLR.Basic +import KLR.TGR.Basic diff --git a/KLR/HLR/AST.lean b/KLR/HLR/AST.lean index 4e51fa3c..55e9eda2 100644 --- a/KLR/HLR/AST.lean +++ b/KLR/HLR/AST.lean @@ -15,13 +15,13 @@ import TensorLib.Tensor open TensorLib (Tensor Shape Dtype Slice) /- -The definition of the High-Level Representation (HLR) IR. The goal of this IR is to +The definition of the Tensor Graph Representation (TGR) IR. The goal of this IR is to be a uniform representation for graphs of tensor operations, which we can use as a common compilation target for different frontends (e.g. StableHLO, PyTorch FX, etc.). -A HLR program consists of a list of functions, each with a name, and input and output tensors. +A TGR program consists of a list of functions, each with a name, and input and output tensors. The function body is in SSA, with each operation producing a single output tensor. -/ -namespace KLR.HLR +namespace KLR.TGR structure TensorTy where shape : Shape @@ -50,7 +50,7 @@ inductive UnaryOp where deriving Inhabited, Repr /- -Operators in the HLR (High-Level Representation) of KLR. +Operators in the TGR (Tensor Graph Representation) IR. Note: some HLO operations have "load-bearing" output shapes, meaning the output shape is a vital part of the operation's semantics (e.g. `reshape`). @@ -110,7 +110,7 @@ inductive Operator where deriving Inhabited, Repr /- -A statement in HLR (High Level Representation). +A statement in TGR (Tensor Graph Representation). In SSA form, so each variable is assigned exactly once. -/ inductive Statement where @@ -128,7 +128,7 @@ inductive Statement where deriving Inhabited, Repr /- -An HLR function. Note that arguments are referred to by index, so +An TGR function. Note that arguments are referred to by index, so we only store the argument shapes, not names. -/ structure Function where @@ -138,7 +138,7 @@ structure Function where statements : List Statement deriving Inhabited, Repr, Nonempty --- An HLR program +-- An TGR program structure Program where functions : List Function deriving Inhabited, Repr, Nonempty @@ -284,4 +284,4 @@ def opName : Operator → String | .slice .. => s!"slice" | .call callee .. => s!"call {callee}" -end KLR.HLR +end KLR.TGR diff --git a/KLR/HLR/Basic.lean b/KLR/HLR/Basic.lean index 609cfa1b..d3985ea8 100644 --- a/KLR/HLR/Basic.lean +++ b/KLR/HLR/Basic.lean @@ -1,4 +1,4 @@ -import KLR.HLR.AST -import KLR.HLR.Compile -import KLR.HLR.Dot -import KLR.HLR.Py +import KLR.TGR.AST +import KLR.TGR.Compile +import KLR.TGR.Dot +import KLR.TGR.Py diff --git a/KLR/HLR/Compile.lean b/KLR/HLR/Compile.lean index e6cf3b8f..fba10586 100644 --- a/KLR/HLR/Compile.lean +++ b/KLR/HLR/Compile.lean @@ -5,7 +5,7 @@ Authors: Paul Biberstein -/ import KLR.Core.Operators -import KLR.HLR.AST +import KLR.TGR.AST import KLR.Util import SHerLOC import TensorLib.Shape @@ -16,8 +16,8 @@ import TensorLib.ByteArray open TensorLib (Dtype Shape Tensor) --- This module compiles a StableHLO program into an HLR program. -namespace KLR.HLR.Compile +-- This module compiles a StableHLO program into an TGR program. +namespace KLR.TGR.Compile abbrev SymbolEnv := List (String × Nat) @@ -171,14 +171,14 @@ def parseStringLiteral := parseInt32TensorFromHex def parseBoolLiteral : StableHLO.Parsing.BooleanLiteral → Bool | .true => true | .false => false --- Parse a StableHLO element literal to an HLR tensor. +-- Parse a StableHLO element literal to an TGR tensor. def parseElementLiteral : StableHLO.Parsing.ElementLiteral → Compile TensorLib.Tensor | .floatLiteral f => f |> parseFloat |> TensorLib.Tensor.arrayScalarFloat32! |> pure | .stringLiteral s => s |> parseStringLiteral | .booleanLiteral b => b |> parseBoolLiteral |> TensorLib.Tensor.arrayScalarBool! |> pure | .complexLiteral _ => impossible "unimplemented" --- Parse a StableHLO dense literal to an HLR tensor. +-- Parse a StableHLO dense literal to an TGR tensor. def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor -- special case for singleton tensors so we don't create an extra dimension | .denseElements [v] => do @@ -188,7 +188,7 @@ def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor | .denseDimension l => do tensorOfTensorList (← l.mapM parseTensorLiteral) --- Convert a StableHLO tensor type to an HLR TensorTy. +-- Convert a StableHLO tensor type to an TGR TensorTy. def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do let shape ← t.shape.mapM (fun | .known d => pure d @@ -208,14 +208,14 @@ def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do | _ => throw s!"Unsupported tensor element type: {repr t.tensorElementTypeGen}" pure (.mk (.mk shape) dtype) --- Parse an HLR TensorTy at index `n` from the list of types. +-- Parse an TGR TensorTy at index `n` from the list of types. def parseTensorTypeFromValueTypes (l : List StableHLO.Parsing.ValueType) (n : Nat): Compile TensorTy := match l[n]? with | .some (.tensorType t) => parseTensorType t | .some t => throw s!"Element {n} of type list must have tensor type, but got {repr t}." | _ => throw s!"Type list must have at least {n + 1} values, but got only {l.length}." --- Parse an HLR TensorTy from a list of types, expecting the list to have exactly one element. +-- Parse an TGR TensorTy from a list of types, expecting the list to have exactly one element. def parseSingleTensorTypeFromValueTypes : List StableHLO.Parsing.ValueType → Compile TensorTy | [.tensorType t] => parseTensorType t | t => throw s!"Expected type list to have a single tensor type, but got {repr t}." @@ -293,13 +293,13 @@ def extractDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile /- The StableHLO `reduce` operation always calls an arbitrary reduction function. -However, in HLR we only support a few specific reduction operations (mostly +However, in TGR we only support a few specific reduction operations (mostly arithmetic and logical binary operators). Since many StableHLO programs only use these basic reduction operations, we can recognize when the StableHLO function called by a `reduce` operation is one of these basic operations, and convert it -to the corresponding HLR BinaryOp. +to the corresponding TGR BinaryOp. If this process is unsuccessful, it means that the input `reduce` function is more -complicated and can't be supported by the current HLR design. +complicated and can't be supported by the current TGR design. -/ def reduceFunctionToReduceOp (f : StableHLO.Parsing.InputFunc) : Compile (BinaryOp) := do match f with @@ -313,13 +313,13 @@ def reduceFunctionToReduceOp (f : StableHLO.Parsing.InputFunc) : Compile (Binary ++ s!"Function: {repr op}") /- -Compile a StableHLO operation into a list of HLR statements. +Compile a StableHLO operation into a list of TGR statements. Note: this function annotates each statement with the type of its output, but this type is merely passed through from the HLO program, not computed anew. This means it's possible that if there's a mistake in the shape calculation -in the HLO program, the HLR statements will also have incorrect shapes. -Eventually, we'll want a function that can shape-check an HLR program. +in the HLO program, the TGR statements will also have incorrect shapes. +Eventually, we'll want a function that can shape-check an TGR program. -/ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := do log s!"Compiling operation {repr op}" @@ -330,11 +330,11 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := | [output] => pure output | _ => throw "Operator signature must have a single output." let outputTy ← parseSingleTensorTypeFromValueTypes signature.range - -- helper function to emit HLR for element-wise unary ops + -- helper function to emit TGR for element-wise unary ops let makeUnOp := fun (op : UnaryOp) => do let a := inputValues[0]! pure [.assign output (.unaryOp op a) outputTy] - -- helper function to emit HLR for element-wise binary ops + -- helper function to emit TGR for element-wise binary ops let makeBinOp := fun (op : BinaryOp) => do let a := inputValues[0]! let b := inputValues[1]! @@ -475,7 +475,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let resultReshapedName ← gensym (output ++ "_reshaped") let resultReshapedShape := [batchSize, lhsResultSize, rhsResultSize] let resultReshapedType := TensorTy.mk (.mk resultReshapedShape) dtype - -- Emit the HLR statements for the dotGeneral operation + -- Emit the TGR statements for the dotGeneral operation pure ([ .comment "Dot General Operation", .assign lhsTransposedName (.transpose lhs lhsTransposePerm) (.mk (.mk lhsTransposedShape) dtype), @@ -541,4 +541,4 @@ def compile (m : List StableHLO.Parsing.Module) : (Except String Unit) × Ctx := | .ok _ s => (.ok (), s) | .error err s => (throw err, s) -end KLR.HLR.Compile +end KLR.TGR.Compile diff --git a/KLR/HLR/Dot.lean b/KLR/HLR/Dot.lean index 8c2af992..18b209c1 100644 --- a/KLR/HLR/Dot.lean +++ b/KLR/HLR/Dot.lean @@ -4,13 +4,13 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Biberstein -/ -import KLR.HLR.AST +import KLR.TGR.AST import SHerLOC.Analysis.Graph open StableHLO.Analysis (Graph Edge Vertex) --- This module provides a way to convert an HLR function into a DOT graph representation. -namespace KLR.HLR.Graph +-- This module provides a way to convert an TGR function into a DOT graph representation. +namespace KLR.TGR.Graph /- Process the name `var` so that it can used as a node ID in DOT format. @@ -29,7 +29,7 @@ def makeReturnNode (funcName : String) : Vertex := ("fillcolor", "lightgray"), ("color", "gray") ]) -def makeOpNode (op : Operator) (output : String) (ty : KLR.HLR.TensorTy): Vertex := +def makeOpNode (op : Operator) (output : String) (ty : KLR.TGR.TensorTy): Vertex := let attrs := match op with | .arg .. => [ ("shape", "diamond"), @@ -72,14 +72,14 @@ def makeEdge (source : String) (dest : String) : Edge := (.mk []) /- -Convert an HLR function to a DOT graph, where each variable is a vertex +Convert an TGR function to a DOT graph, where each variable is a vertex and an edge exists from A to B if A is used in the computation of B. Note: since constants are reused in many parts of the function, they can cause the graph to have long edges that cross over other nodes. To avoid this, we create a separate vertex for each use of a constant. -/ -def graph (f : HLR.Function) : Graph := Id.run do +def graph (f : TGR.Function) : Graph := Id.run do let mut vertices := [] let mut edges := [] -- Every variables in the function that is the result of a `constant` operatior @@ -119,4 +119,4 @@ def graph (f : HLR.Function) : Graph := Id.run do .mk f.name vertices edges -end KLR.HLR.Graph +end KLR.TGR.Graph diff --git a/KLR/HLR/Py.lean b/KLR/HLR/Py.lean index 6fc8b239..9f4e744b 100644 --- a/KLR/HLR/Py.lean +++ b/KLR/HLR/Py.lean @@ -5,7 +5,7 @@ Authors: Paul Biberstein -/ import KLR.Core.Operators -import KLR.HLR.AST +import KLR.TGR.AST import KLR.Util import SHerLOC import TensorLib.Shape @@ -15,12 +15,12 @@ open Std.Format open TensorLib (Dtype Shape Slice) /- -This module converts an HLR program into a runnable Python program. -At present, it can't convert HLR constants to python constants and can't +This module converts an TGR program into a runnable Python program. +At present, it can't convert TGR constants to python constants and can't take input tensors, so it is only helpful to ensure that the shape annotations are correct and that the program is well-formed. -/ -namespace KLR.HLR.Py +namespace KLR.TGR.Py structure FormatCtx where indent : Nat := 0 @@ -148,8 +148,8 @@ def compileProgram (p : Program) : Format := p.functions.map compileFunction joinSep lines line --- Compile the HLR program to a Python program. +-- Compile the TGR program to a Python program. def compile (p : Program) : String := (compileProgram p).pretty -end KLR.HLR.Py +end KLR.TGR.Py diff --git a/Main.lean b/Main.lean index 944cc011..3e6a6caf 100644 --- a/Main.lean +++ b/Main.lean @@ -19,7 +19,7 @@ import KLR import TensorLib.Npy import TensorLib.Tensor import SHerLOC -import KLR.HLR +import KLR.TGR import SHerLOC.Analysis.Graph open Cli @@ -301,26 +301,26 @@ def evalKLR (p : Parsed) : IO UInt32 := do TensorLib.Npy.Ndarray.save! arr filename return 0 -def hloToHLR (p : Parsed) : IO UInt32 := do +def hloToTGR (p : Parsed) : IO UInt32 := do let file := p.positionalArg! "file" |>.as! String let s <- IO.FS.readFile file match StableHLO.Parsing.parse s with | .ok (hlo, _) => - let hlr := KLR.HLR.Compile.compile hlo + let hlr := KLR.TGR.Compile.compile hlo match hlr with | (.ok _, s) => do let hlr := s.program IO.println (toString hlr) let headFunction := s.program.functions.head! -- print graph of function - let g := KLR.HLR.Graph.graph headFunction |> toString + let g := KLR.TGR.Graph.graph headFunction |> toString writeContent "dot" p g - -- print HLR program as Python program - let py := KLR.HLR.Py.compile hlr + -- print TGR program as Python program + let py := KLR.TGR.Py.compile hlr writeContent "py" p py return 0 | (.error e, s) => do - IO.eprintln s!"Error compiling HLO to HLR: {e}" + IO.eprintln s!"Error compiling HLO to TGR: {e}" IO.eprintln s!"{repr s}" return 1 | .error e => @@ -417,9 +417,9 @@ def evalKLRCmd := `[Cli| kernelFunctionName : String; "Name of the kernel function" ...inputFiles : String; ".npy files corresponding to the inputs to the kernel, in positional order" ] -def hloToHLRCmd := `[Cli| - "hlo-to-hlr" VIA hloToHLR; - "Compile HLO graph to HLR graph" +def hloToTGRCmd := `[Cli| + "hlo-to-hlr" VIA hloToTGR; + "Compile HLO graph to TGR graph" FLAGS: o, outfile : String; "Name of output file" @@ -439,7 +439,7 @@ def klrCmd : Cmd := `[Cli| nkiToKLRCmd; traceCmd; typecheckCmd; - hloToHLRCmd + hloToTGRCmd ] def main (args : List String) : IO UInt32 := do From b0b0bac8fcaae2fd0ba9a83029f6c183de5c3214 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Tue, 15 Jul 2025 16:06:21 -0400 Subject: [PATCH 08/15] Finish renaming --- KLR/KLR/Compile.lean | 0 KLR/{HLR.lean => TGR.lean} | 0 KLR/{HLR => TGR}/AST.lean | 0 KLR/{HLR => TGR}/Basic.lean | 0 KLR/{HLR => TGR}/Compile.lean | 0 KLR/{HLR => TGR}/Dot.lean | 0 KLR/{HLR => TGR}/Py.lean | 0 Main.lean | 12 ++++++------ 8 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 KLR/KLR/Compile.lean rename KLR/{HLR.lean => TGR.lean} (100%) rename KLR/{HLR => TGR}/AST.lean (100%) rename KLR/{HLR => TGR}/Basic.lean (100%) rename KLR/{HLR => TGR}/Compile.lean (100%) rename KLR/{HLR => TGR}/Dot.lean (100%) rename KLR/{HLR => TGR}/Py.lean (100%) diff --git a/KLR/KLR/Compile.lean b/KLR/KLR/Compile.lean new file mode 100644 index 00000000..e69de29b diff --git a/KLR/HLR.lean b/KLR/TGR.lean similarity index 100% rename from KLR/HLR.lean rename to KLR/TGR.lean diff --git a/KLR/HLR/AST.lean b/KLR/TGR/AST.lean similarity index 100% rename from KLR/HLR/AST.lean rename to KLR/TGR/AST.lean diff --git a/KLR/HLR/Basic.lean b/KLR/TGR/Basic.lean similarity index 100% rename from KLR/HLR/Basic.lean rename to KLR/TGR/Basic.lean diff --git a/KLR/HLR/Compile.lean b/KLR/TGR/Compile.lean similarity index 100% rename from KLR/HLR/Compile.lean rename to KLR/TGR/Compile.lean diff --git a/KLR/HLR/Dot.lean b/KLR/TGR/Dot.lean similarity index 100% rename from KLR/HLR/Dot.lean rename to KLR/TGR/Dot.lean diff --git a/KLR/HLR/Py.lean b/KLR/TGR/Py.lean similarity index 100% rename from KLR/HLR/Py.lean rename to KLR/TGR/Py.lean diff --git a/Main.lean b/Main.lean index 3e6a6caf..2813d5f6 100644 --- a/Main.lean +++ b/Main.lean @@ -306,17 +306,17 @@ def hloToTGR (p : Parsed) : IO UInt32 := do let s <- IO.FS.readFile file match StableHLO.Parsing.parse s with | .ok (hlo, _) => - let hlr := KLR.TGR.Compile.compile hlo - match hlr with + let tgr := KLR.TGR.Compile.compile hlo + match tgr with | (.ok _, s) => do - let hlr := s.program - IO.println (toString hlr) + let tgr := s.program + IO.println (toString tgr) let headFunction := s.program.functions.head! -- print graph of function let g := KLR.TGR.Graph.graph headFunction |> toString writeContent "dot" p g -- print TGR program as Python program - let py := KLR.TGR.Py.compile hlr + let py := KLR.TGR.Py.compile tgr writeContent "py" p py return 0 | (.error e, s) => do @@ -418,7 +418,7 @@ def evalKLRCmd := `[Cli| ...inputFiles : String; ".npy files corresponding to the inputs to the kernel, in positional order" ] def hloToTGRCmd := `[Cli| - "hlo-to-hlr" VIA hloToTGR; + "hlo-to-tgr" VIA hloToTGR; "Compile HLO graph to TGR graph" FLAGS: From 4a924672138961324805f05d9998e5bd9f4d179f Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Thu, 17 Jul 2025 11:51:35 -0400 Subject: [PATCH 09/15] Fix tgr graph creation --- KLR/TGR/Dot.lean | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/KLR/TGR/Dot.lean b/KLR/TGR/Dot.lean index 18b209c1..cd0dd39f 100644 --- a/KLR/TGR/Dot.lean +++ b/KLR/TGR/Dot.lean @@ -51,14 +51,14 @@ def makeOpNode (op : Operator) (output : String) (ty : KLR.TGR.TensorTy): Vertex .mk (sanitize output) (.mk ([ - ("label", s!"{opName op}\\n{output}\n{ty.shape}"), + ("label", s!"{opName op}\\n{output}\\n{ty.shape}"), ] ++ attrs)) -def makeConstNode (name : String) (usedAt : String): Vertex := +def makeConstNode (op : String) (name : String) (shape : TensorTy) (usedBy : String): Vertex := .mk - s!"node_const_{name}_{usedAt}" + s!"node_const_{name}_{usedBy}" (.mk [ - ("label", s!"const\\n{name} ({usedAt})"), + ("label", s!"{op}\\n{name}\\n{shape.shape}"), ("shape", "diamond"), ("style", "filled"), ("fillcolor", "lightyellow"), @@ -84,7 +84,8 @@ def graph (f : TGR.Function) : Graph := Id.run do let mut edges := [] -- Every variables in the function that is the result of a `constant` operatior let mut consts := f.statements.filterMap (fun - | .assign v (.const _ _ _) _ => .some v + | .assign v (.const ..) shape => .some ("const", v, shape) + | .assign v (.full ..) shape => .some ("full", v, shape) | _ => .none) -- A closure that creates edges from a list of inputs to an output variable. -- If the input is a constant, it creates a vertex for that constant. @@ -92,31 +93,31 @@ def graph (f : TGR.Function) : Graph := Id.run do let mut vertices := [] let mut edges := [] for input in inputs do - if consts.contains input then - let node := makeConstNode input output + if let .some (op, v, shape) := consts.find? fun (_, v, _) => v == input then + let node := makeConstNode op v shape output vertices := node :: vertices edges := (makeEdge node.id output) :: edges else - edges := (makeEdge input output) :: edges + edges := (makeEdge (sanitize input) output) :: edges return (vertices, edges) -- Walk the program statements and create vertices and edges. for s in f.statements do match s with - | .assign _ (.const _ _ _) _ => () + | .assign _ (.const ..) _ | .assign _ (.full ..) _ => pure () | .assign v op ty => - let deps := dependencies op |>.map sanitize - let (newVertices, newEdges) := makeEdges deps (sanitize v) + let deps := dependencies op + let (newVertices, newEdges) ← makeEdges deps (sanitize v) vertices := [makeOpNode op v ty] ++ newVertices ++ vertices edges := newEdges ++ edges | .ret vars => let node := makeReturnNode f.name - let deps := vars.map sanitize - let (newVertices, newEdges) := makeEdges deps node.id + let deps := vars + let (newVertices, newEdges) ← makeEdges deps node.id vertices := [node] ++ newVertices ++ vertices edges := newEdges ++ edges - | .comment _ => () + | .comment _ => pure () - .mk f.name vertices edges + pure $ .mk f.name vertices edges end KLR.TGR.Graph From f8d02da391f54b7c622a2d4c5bda32afe5da11bc Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Fri, 18 Jul 2025 11:24:29 -0400 Subject: [PATCH 10/15] Bust CI cache --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b381cc47..2f496d50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,8 +49,8 @@ jobs: # Leaving in this workaround for some spurious crashes during the build, due to a bug in lean-action, # in case it's not really fixed. # https://github.com/leanprover/lean-action/issues/116#issuecomment-2663316227 - # with: - # use-github-cache: false + with: + use-github-cache: false - name: Run Lean tests run: lake exe klr From 9b09d4811f91c92b0a06b0d8fda553cd4ec29ba4 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Fri, 18 Jul 2025 11:26:31 -0400 Subject: [PATCH 11/15] Restore cache --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2f496d50..b381cc47 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,8 +49,8 @@ jobs: # Leaving in this workaround for some spurious crashes during the build, due to a bug in lean-action, # in case it's not really fixed. # https://github.com/leanprover/lean-action/issues/116#issuecomment-2663316227 - with: - use-github-cache: false + # with: + # use-github-cache: false - name: Run Lean tests run: lake exe klr From c33bb720505647dc824645eb4668d8b7bc96e3f1 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Fri, 18 Jul 2025 11:52:44 -0400 Subject: [PATCH 12/15] Update manifest of gzip lib --- KLR/Util/Gzip/lake-manifest.json | 36 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/KLR/Util/Gzip/lake-manifest.json b/KLR/Util/Gzip/lake-manifest.json index b44f34e4..8ed88edd 100644 --- a/KLR/Util/Gzip/lake-manifest.json +++ b/KLR/Util/Gzip/lake-manifest.json @@ -25,6 +25,16 @@ "inputRev": "v4.21.0", "inherited": false, "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": true, + "configFile": "lakefile.lean"}, {"url": "https://github.com/leanprover/TensorLib.git", "type": "git", "subDir": null, @@ -39,38 +49,38 @@ "type": "git", "subDir": null, "scope": "", - "rev": "c4aa78186d388e50a436e8362b947bae125a2933", + "rev": "2ac43674e92a695e96caac19f4002b25434636da", "name": "plausible", "manifestFile": "lake-manifest.json", - "inputRev": "v4.21.0", + "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover-community/batteries", + {"url": "https://github.com/leanprover-community/import-graph.git", "type": "git", "subDir": null, - "scope": "leanprover-community", - "rev": "8d2067bf518731a70a255d4a61b5c103922c772e", - "name": "batteries", + "scope": "", + "rev": "a11bcb5238149ae5d8a0aa5e2f8eddf8a3a9b27d", + "name": "importGraph", "manifestFile": "lake-manifest.json", - "inputRev": "v4.21.0", + "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}, {"url": "https://github.com/leanprover-community/aesop", "type": "git", "subDir": null, "scope": "", - "rev": "8ff27701d003456fd59f13a9212431239d902aef", + "rev": "ddfca7829bf8aa4083cdf9633935dddbb28b7b2a", "name": "aesop", "manifestFile": "lake-manifest.json", - "inputRev": "v4.21.0", + "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover-community/import-graph.git", + {"url": "https://github.com/leanprover-community/batteries", "type": "git", "subDir": null, - "scope": "", - "rev": "a11bcb5238149ae5d8a0aa5e2f8eddf8a3a9b27d", - "name": "importGraph", + "scope": "leanprover-community", + "rev": "7a0d63fbf8fd350e891868a06d9927efa545ac1e", + "name": "batteries", "manifestFile": "lake-manifest.json", "inputRev": "v4.20.0", "inherited": true, From d196fe8a0560ccf58f623a75be62bc1aaba7bc7b Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 23 Jul 2025 10:44:40 -0400 Subject: [PATCH 13/15] Replace single line comments with multiline --- KLR/TGR/AST.lean | 63 ++++++++++++++--------------- KLR/TGR/Compile.lean | 94 ++++++++++++++++++++++---------------------- KLR/TGR/Dot.lean | 10 ++--- KLR/TGR/Py.lean | 4 +- 4 files changed, 83 insertions(+), 88 deletions(-) diff --git a/KLR/TGR/AST.lean b/KLR/TGR/AST.lean index 55e9eda2..e590c895 100644 --- a/KLR/TGR/AST.lean +++ b/KLR/TGR/AST.lean @@ -30,7 +30,7 @@ deriving Inhabited, Repr, Nonempty abbrev Var := String --- scalar-scalar binary operators +/- scalar-scalar binary operators -/ inductive BinaryOp where | add | sub @@ -41,7 +41,7 @@ inductive BinaryOp where | cmp deriving Inhabited, Repr --- scalar unary operators +/- scalar unary operators -/ inductive UnaryOp where | exp | sqrt @@ -60,33 +60,31 @@ the output shape information exists in two redundant places: in the `Statement` and in the `Operator`. -/ inductive Operator where - -- An argument to the function, identified by its index. + /- An argument to the function, identified by its index. -/ | arg (index : Nat) - -- apply a binary operator element-wise to two tensors + /- apply a binary operator element-wise to two tensors -/ | binaryOp (op : BinaryOp) (a b : Var) - -- apply a unary operator element-wise to a tensor + /- apply a unary operator element-wise to a tensor -/ | unaryOp (op : UnaryOp) (a : Var) - -- apply a reduction operation to a tensor, reducing it along the specified dimensions + /- apply a reduction operation to a tensor, reducing it along the specified dimensions -/ | reductionOp (op : BinaryOp) (a b : Var) (dim : List Nat) - -- perform a batch matrix multiplication on two tensors. - -- Specifically, computes the einsum bij,bkj->bik + /- perform a batch matrix multiplication on two tensors. + Specifically, computes the einsum bij,bkj->bik -/ | batchMatmul (a b : Var) - -- create a tensor with a range of values within the given limits and with the specified stride + /- create a tensor with a range of values within the given limits and with the specified stride -/ | arange (start : Nat) (stop : Nat) (step : Nat) (shape : Shape) - -- concatenate a list of tensors along the specified dimension + /- concatenate a list of tensors along the specified dimension -/ | concat (tensors : List Var) (dim : Nat) - -- select elements from two tensors based on a condition tensor + /- select elements from two tensors based on a condition tensor -/ | select (cond a b : Var) - -- create a tensor filled with a specific value, with the given shape - -- Note: the tensor is always a scalar-array + /- create a tensor filled with a specific value, with the given shape + Note: the tensor is always a scalar-array -/ | full (value : Tensor) (shape : Shape) - -- transpose a tensor with the provided permutation of dimensions + /- transpose a tensor with the provided permutation of dimensions -/ | transpose (a : Var) (dims : List Nat) - -- unused - | split_with_sizes (a : Var) (sizes : List Nat) -- ?? - -- reshape a tensor to the specified shape + /- reshape a tensor to the specified shape -/ | reshape (a : Var) (shape : Shape) /- broadcast a tensor to the specified shape @@ -97,15 +95,15 @@ inductive Operator where existing ones of size 1. -/ | broadcast (a : Var) (shape : Shape) - -- create a constant tensor with the given values and shape + /- create a constant tensor with the given values and shape -/ | const (values : Tensor) (shape : Shape) (dtype : Dtype) - -- gather elements from a tensor using the provided indices and offset dimensions - -- TODO: gather is complicated and not used except for in llama, so for now - -- we just pass through the semantics of HLO's gather + /- gather elements from a tensor using the provided indices and offset dimensions + TODO: gather is complicated and not used except for in llama, so for now + we just pass through the semantics of HLO's gather -/ | gather (input indices : Var) (offsetDims collapsedSliceDims startIndexMap : List Nat) (indexVectorDim : Nat) - -- slice a tensor along specified dimensions, with start, limit, and stride + /- slice a tensor along specified dimensions, with start, limit, and stride -/ | slice (a : Var) (slice : List Slice) - -- call another function, passing input values and receiving outputs + /- call another function, passing input values and receiving outputs -/ | call (callee : String) (inputValues : List Var) deriving Inhabited, Repr @@ -114,7 +112,7 @@ A statement in TGR (Tensor Graph Representation). In SSA form, so each variable is assigned exactly once. -/ inductive Statement where - -- A comment in the code, for making the dumped IR readable + /- A comment in the code, for making the dumped IR readable -/ | comment (msg : String) /- Assign the result of `op` to `dest` , with resulting shape `shape` @@ -123,7 +121,7 @@ inductive Statement where operator, to avoid having to recompute it with fallible operations later. -/ | assign (dest : Var) (op : Operator) (shape : TensorTy) - -- Return variables `vars` from the function + /- Return variables `vars` from the function -/ | ret (vars : List Var) deriving Inhabited, Repr @@ -138,12 +136,12 @@ structure Function where statements : List Statement deriving Inhabited, Repr, Nonempty --- An TGR program +/- A TGR program -/ structure Program where functions : List Function deriving Inhabited, Repr, Nonempty --- Returns the list of variables that this operator immediately depends on. +/- Returns the list of variables that this operator immediately depends on. -/ def dependencies : Operator → List Var | .arg _ => [] | .binaryOp _ a b => [a, b] @@ -155,7 +153,6 @@ def dependencies : Operator → List Var | .select cond a b => [cond, a, b] | .full .. => [] | .transpose a _ => [a] - | .split_with_sizes a _ => [a] | .reshape a _ => [a] | .broadcast a .. => [a] | .const .. => [] @@ -163,19 +160,19 @@ def dependencies : Operator → List Var | .slice a .. => [a] | .call _ inputs => inputs --- Returns the list of all variables defined in this function. +/- Returns the list of all variables defined in this function. -/ def vars (f : Function) : List Var := f.statements.filterMap (fun | .assign dest .. => .some dest | _ => .none) --- Finds the operator that assigns to a variable in the function. +/- Finds the operator that assigns to a variable in the function. -/ def findVar (f : Function) (v : Var) : Option Operator := f.statements.findSome? (fun | .assign dest op _ => if dest == v then .some op else .none | _ => .none) --- TODO: move these toString instances to the TensorLib repo +/- TODO: move these toString instances to the TensorLib repo -/ instance : ToString Slice where toString s := let {start, stop, step, ..} := s @@ -235,7 +232,6 @@ instance : ToString Operator where | .select cond a b => s!"select({cond}, {a}, {b})" | .full v shape => s!"full({repr v}, shape={shape})" | .transpose a dims => s!"transpose({a}, dims={dims})" - | .split_with_sizes a sizes => s!"split_with_sizes({a}, sizes={sizes})" | .reshape a shape => s!"reshape({a}, shape={shape})" | .broadcast a shape => s!"broadcast({a}, shape={shape})" | .const t shape dtype => s!"const({repr t}, shape={shape}, dtype={dtype})" @@ -264,7 +260,7 @@ instance : ToString Program where let functionsStr := p.functions.map toString |> "\n".intercalate s!"# Program\n" ++ functionsStr --- Human readable name for the operator. +/- Human readable name for the operator. -/ def opName : Operator → String | .arg _ => s!"arg" | .binaryOp binOp .. => s!"{binOp}" @@ -276,7 +272,6 @@ def opName : Operator → String | .select .. => s!"select" | .full .. => s!"full" | .transpose .. => s!"transpose" - | .split_with_sizes .. => s!"split_with_sizes" | .reshape .. => s!"reshape" | .broadcast .. => s!"broadcast" | .const .. => s!"const" diff --git a/KLR/TGR/Compile.lean b/KLR/TGR/Compile.lean index fba10586..eae5785b 100644 --- a/KLR/TGR/Compile.lean +++ b/KLR/TGR/Compile.lean @@ -16,28 +16,28 @@ import TensorLib.ByteArray open TensorLib (Dtype Shape Tensor) --- This module compiles a StableHLO program into an TGR program. +/- This module compiles a StableHLO program into an TGR program. -/ namespace KLR.TGR.Compile abbrev SymbolEnv := List (String × Nat) --- Context for the compilation process, to be stored in a state monad. +/- Context for the compilation process, to be stored in a state monad. -/ structure Ctx where - -- the program being compiled + /- the program being compiled -/ program : Program - -- the log of messages generated during compilation (for debugging) + /- the log of messages generated during compilation (for debugging) -/ log : List String gensymEnv : SymbolEnv deriving Inhabited, Repr --- Compilation requires tracking state and also potentially returning errors. +/- Compilation requires tracking state and also potentially returning errors. -/ abbrev Compile T := StM Ctx T --- Emit a message to the compilation log. +/- Emit a message to the compilation log. -/ def log (msg : String) : Compile Unit := modify (fun ctx => { ctx with log := ctx.log ++ [msg]}) --- Add a function to the program being compiled. +/- Add a function to the program being compiled. -/ def addFunction (func : Function) : Compile Unit := do modify (fun ctx => { ctx with program := { ctx.program with functions := ctx.program.functions ++ [func] } }) @@ -57,7 +57,7 @@ def gensym (name : String) : Compile Var := do modify (fun ctx => { ctx with gensymEnv := (name, idx) :: ctx.gensymEnv }) pure s!"{name}_{idx}" --- Permute `l` according to the indices in `permutation`. +/- Permute `l` according to the indices in `permutation`. -/ def permute {T : Type} (l : List T) (permutation : List Nat) : Option (List T) := permutation.mapM fun dim => l[dim]? @@ -117,7 +117,7 @@ into a tensor of int32 values. Assumes bytes are in big-endian order. -/ def parseInt32TensorFromHex (s : String) : Compile Tensor := do - -- Tail recursive helper to convert a list of hex characters to a list of BitVec 32 values. + /- Tail recursive helper to convert a list of hex characters to a list of BitVec 32 values. -/ let rec toBitVec32List (str : List Char) (acc : List (BitVec 32)) : Compile (List (BitVec 32)):= do match str with | c0 :: c1 :: c2 :: c3 :: c4 :: c5 :: c6 :: c7 :: rest => @@ -127,11 +127,11 @@ def parseInt32TensorFromHex (s : String) : Compile Tensor := do | [] => pure acc.reverse | _ => throw s!"Hex string must have a number of characters divisible by 8, but got {str}." - -- Trim off the leading "0x" and convert the rest to a list of BitVec 32 values. + /- Trim off the leading "0x" and convert the rest to a list of BitVec 32 values. -/ let bvs ← match s.toList with | '0' :: 'x' :: rest => toBitVec32List rest [] | _ => throw s!"Hex string must start with '0x', but got {s}." - -- Concatenate the bitvecs to create a little-endian bytearray + /- Concatenate the bitvecs to create a little-endian bytearray -/ let data := bvs.foldr (fun new acc => (TensorLib.toLEByteArray new) ++ acc) ByteArray.empty pure { dtype := TensorLib.Dtype.int32, @@ -139,7 +139,7 @@ def parseInt32TensorFromHex (s : String) : Compile Tensor := do data } --- Convert a list of tensors to a single tensor by concatenating them along a new first dimension. +/- Convert a list of tensors to a single tensor by concatenating them along a new first dimension. -/ def tensorOfTensorList (ns : List Tensor) : Compile Tensor := match ns with | f :: r => do @@ -167,20 +167,20 @@ Note that the meaning of string literals in StableHLO is not well-defined, but in practice JAX uses them to hex encode integer tensors. -/ def parseStringLiteral := parseInt32TensorFromHex --- Parse a StableHLO boolean literal to a Bool. +/- Parse a StableHLO boolean literal to a Bool. -/ def parseBoolLiteral : StableHLO.Parsing.BooleanLiteral → Bool | .true => true | .false => false --- Parse a StableHLO element literal to an TGR tensor. +/- Parse a StableHLO element literal to an TGR tensor. -/ def parseElementLiteral : StableHLO.Parsing.ElementLiteral → Compile TensorLib.Tensor | .floatLiteral f => f |> parseFloat |> TensorLib.Tensor.arrayScalarFloat32! |> pure | .stringLiteral s => s |> parseStringLiteral | .booleanLiteral b => b |> parseBoolLiteral |> TensorLib.Tensor.arrayScalarBool! |> pure | .complexLiteral _ => impossible "unimplemented" --- Parse a StableHLO dense literal to an TGR tensor. +/- Parse a StableHLO dense literal to an TGR tensor. -/ def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor - -- special case for singleton tensors so we don't create an extra dimension + /- special case for singleton tensors so we don't create an extra dimension -/ | .denseElements [v] => do parseElementLiteral v | .denseElements l => do @@ -188,7 +188,7 @@ def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor | .denseDimension l => do tensorOfTensorList (← l.mapM parseTensorLiteral) --- Convert a StableHLO tensor type to an TGR TensorTy. +/- Convert a StableHLO tensor type to an TGR TensorTy. -/ def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do let shape ← t.shape.mapM (fun | .known d => pure d @@ -208,19 +208,19 @@ def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do | _ => throw s!"Unsupported tensor element type: {repr t.tensorElementTypeGen}" pure (.mk (.mk shape) dtype) --- Parse an TGR TensorTy at index `n` from the list of types. +/- Parse an TGR TensorTy at index `n` from the list of types. -/ def parseTensorTypeFromValueTypes (l : List StableHLO.Parsing.ValueType) (n : Nat): Compile TensorTy := match l[n]? with | .some (.tensorType t) => parseTensorType t | .some t => throw s!"Element {n} of type list must have tensor type, but got {repr t}." | _ => throw s!"Type list must have at least {n + 1} values, but got only {l.length}." --- Parse an TGR TensorTy from a list of types, expecting the list to have exactly one element. +/- Parse an TGR TensorTy from a list of types, expecting the list to have exactly one element. -/ def parseSingleTensorTypeFromValueTypes : List StableHLO.Parsing.ValueType → Compile TensorTy | [.tensorType t] => parseTensorType t | t => throw s!"Expected type list to have a single tensor type, but got {repr t}." --- Parse an array from a StableHLO literal. +/- Parse an array from a StableHLO literal. -/ def parseArray (c : StableHLO.Parsing.Literal) : Compile (List Nat) := match c with | .array (.array64 arr) => pure (arr.map fun ⟨ _sign, n ⟩ => n) @@ -243,31 +243,31 @@ def parseNatFromElementLiteral (c : StableHLO.Parsing.Literal) : Compile Nat := | .element (.floatLiteral l) => throw s!"Got unsupported float literal {repr l}." | l => throw s!"Expected a float literal but got {repr l}." --- Find an attribute by name in a list of attributes +/- Find an attribute by name in a list of attributes -/ def lookupAttribute (attrs : List StableHLO.Parsing.Attribute) (name : String) : Compile StableHLO.Parsing.Constant := match attrs.find? (fun ⟨ id, _ ⟩ => id == name) with | some ⟨ _, attr ⟩ => pure attr | none => throw s!"Attribute '{name}' not found." --- Find an attribute by name in a list of attributes, returning only the associated literal, not its type +/- Find an attribute by name in a list of attributes, returning only the associated literal, not its type -/ def lookupAttributeValue (attrs : List StableHLO.Parsing.Attribute) (name : String) : Compile StableHLO.Parsing.Literal := lookupAttribute attrs name |>.map (fun ⟨ lit, _ ⟩ => lit) --- Get the value of a field in a StableHLO record, expecting it to be a list of integers. +/- Get the value of a field in a StableHLO record, expecting it to be a list of integers. -/ def lookupNatsInFields (fields : List StableHLO.Parsing.StableHLORecordField) (name : String) : Compile (List Nat) := match fields.find? (fun ⟨ n, _ ⟩ => n == name) with | some (.mk _ (.many ns)) => pure ns | some v => throw s!"Field '{name}' must be a list of integers, but got {repr v}." | none => pure [] --- Get the value of a field in a StableHLO record, expecting it to be a single integer. +/- Get the value of a field in a StableHLO record, expecting it to be a single integer. -/ def lookupNatInFields (fields : List StableHLO.Parsing.StableHLORecordField) (name : String) : Compile Nat := match fields.find? (fun ⟨ n, _ ⟩ => n == name) with | some (.mk _ (.one n)) => pure n | some v => throw s!"Field '{name}' must be a single integer, but got {repr v}." | none => throw s!"Field '{name}' not found in record list {repr fields}." --- extract the arguments to the `dotGeneral` operation from a record in the list of attributes +/- extract the arguments to the `dotGeneral` operation from a record in the list of attributes -/ def extractDotDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile (List Nat × List Nat × List Nat × List Nat) := do let dotAttr ← lookupAttributeValue attrs "dot_dimension_numbers" match dotAttr with @@ -279,7 +279,7 @@ def extractDotDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Comp pure (lhs_batching_dims, lhs_contracting_dims, rhs_batching_dims, rhs_contracting_dims) | _ => throw "Attribute 'dot_dimension_numbers' must be a stableHLORecord." --- extract the arguments to the `gather` operation from a record in the list of attributes +/- extract the arguments to the `gather` operation from a record in the list of attributes -/ def extractDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile (List Nat × List Nat × List Nat × Nat) := do let attr ← lookupAttributeValue attrs "dimension_numbers" match attr with @@ -325,27 +325,27 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := log s!"Compiling operation {repr op}" match op with | .stablehlo opCode inputValues inputFunctions inputAttributes outputs signature => do - -- Reuse the variable names and shapes from the StableHLO program + /- Reuse the variable names and shapes from the StableHLO program -/ let output ← match outputs with | [output] => pure output | _ => throw "Operator signature must have a single output." let outputTy ← parseSingleTensorTypeFromValueTypes signature.range - -- helper function to emit TGR for element-wise unary ops + /- helper function to emit TGR for element-wise unary ops -/ let makeUnOp := fun (op : UnaryOp) => do let a := inputValues[0]! pure [.assign output (.unaryOp op a) outputTy] - -- helper function to emit TGR for element-wise binary ops + /- helper function to emit TGR for element-wise binary ops -/ let makeBinOp := fun (op : BinaryOp) => do let a := inputValues[0]! let b := inputValues[1]! pure [.assign output (.binaryOp op a b) outputTy] match opCode with - -- element-wise unary operators + /- element-wise unary operators -/ | .sqrt => makeUnOp .sqrt | .negate => makeUnOp .neg | .exponential => makeUnOp .exp | .convert => makeUnOp (UnaryOp.convert outputTy.dtype) - -- element-wise binary operators + /- element-wise binary operators -/ | .compare => makeBinOp .cmp | .multiply => makeBinOp .mul | .add => makeBinOp .add @@ -353,25 +353,25 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := | .maximum => makeBinOp .max | .subtract => makeBinOp .sub | .divide => makeBinOp .div - -- tensor nullary operators + /- tensor nullary operators -/ | .constant => do let valueAttr ← lookupAttributeValue inputAttributes "value" match valueAttr with | (.tensor t) => do let t ← parseTensorLiteral t if t.shape == Shape.empty then - -- If the tensor is a scalar-array, we use a `full` operation - -- to create a tensor of the same shape as the output. + /- If the tensor is a scalar-array, we use a `full` operation + to create a tensor of the same shape as the output. -/ pure [.assign output (.full t outputTy.shape) outputTy] else - -- If the tensor is not a scalar-array, it corresponds to a - -- `const` operation. + /- If the tensor is not a scalar-array, it corresponds to a + `const` operation. -/ if t.shape.count != outputTy.shape.count then throw s!"Tensor literal shape {t.shape} does not match expected output shape {outputTy.shape}." let t ← t.reshape outputTy.shape pure [.assign output (.const t outputTy.shape outputTy.dtype) outputTy] | _ => throw "Constant operation requires a 'value' attribute with tensor literal." - -- tensor unary operators + /- tensor unary operators -/ | .reshape => do let input := inputValues[0]! pure [.assign output (.reshape input outputTy.shape) outputTy] @@ -410,8 +410,8 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let inputTy := ← parseTensorTypeFromValueTypes signature.domain 0 let broadcastDims ← (← lookupAttributeValue inputAttributes "broadcast_dimensions") |> parseArray let reshaped ← gensym (input ++ "_reshaped") - -- A shape that has the same number of dimensions as the output, but where - -- specified dimensions match the input shape, and others are 1. + /- A shape that has the same number of dimensions as the output, but where + specified dimensions match the input shape, and others are 1. -/ let newShape := outputTy.shape.ndim |> List.range |> List.map (fun n => if let .some i := broadcastDims.idxOf? n then inputTy.shape.val[i]! @@ -426,7 +426,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let input := inputValues[0]! let dims ← (← lookupAttributeValue inputAttributes "permutation") |> parseArray pure [.assign output (.transpose input dims) outputTy] - -- tensor binary operators + /- tensor binary operators -/ | .dotGeneral => do /- The semantics of the `dotGeneral` operation are complex, see the @@ -434,7 +434,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := variables in the spec to aid comprehension. https://github.com/openxla/stablehlo/blob/6f7b4ab8f96dc65cf3c8e9824836117d2934cc45/docs/spec.md?#dot_general -/ - -- Gather metadata from the inputs + /- Gather metadata from the inputs -/ let (lhsBatchingDims, lhsContractingDims, rhsBatchingDims, rhsContractingDims) ← extractDotDimensionNumbers inputAttributes let lhs := inputValues[0]! @@ -446,7 +446,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let rhsShape := rhsType.shape let lhsDims := List.range (TensorLib.Shape.ndim lhsShape) let rhsDims := List.range (TensorLib.Shape.ndim rhsShape) - -- Calculate shapes of intermediate tensors and output + /- Calculate shapes of intermediate tensors and output -/ let lhsResultDims := lhsDims.filter (fun i => !lhsBatchingDims.contains i && !lhsContractingDims.contains i) let rhsResultDims := rhsDims.filter (fun i => !rhsBatchingDims.contains i && !rhsContractingDims.contains i) let lhsTransposePerm := lhsBatchingDims ++ lhsResultDims ++ lhsContractingDims @@ -463,7 +463,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let lhsResultSize := if lhsResultShape.isEmpty then 1 else lhsResultShape.foldl (· * ·) (1 : Nat) let rhsResultSize := if rhsResultShape.isEmpty then 1 else rhsResultShape.foldl (· * ·) (1 : Nat) let contractingSize := if contractingShape.isEmpty then 1 else contractingShape.foldl (· * ·) 1 - -- Create fresh variable names for intermediate results + /- Create fresh variable names for intermediate results -/ let lhsTransposedName ← gensym (lhs ++ "_transposed") let rhsTransposedName ← gensym (rhs ++ "_transposed") let lhsReshapedName ← gensym (lhs ++ "_reshaped") @@ -475,7 +475,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let resultReshapedName ← gensym (output ++ "_reshaped") let resultReshapedShape := [batchSize, lhsResultSize, rhsResultSize] let resultReshapedType := TensorTy.mk (.mk resultReshapedShape) dtype - -- Emit the TGR statements for the dotGeneral operation + /- Emit the TGR statements for the dotGeneral operation -/ pure ([ .comment "Dot General Operation", .assign lhsTransposedName (.transpose lhs lhsTransposePerm) (.mk (.mk lhsTransposedShape) dtype), @@ -489,7 +489,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := let tensors := inputValues let dim ← (← lookupAttributeValue inputAttributes "dimension") |> parseNatFromElementLiteral pure [.assign output (.concat tensors dim) outputTy] - -- tensor ternary operators + /- tensor ternary operators -/ | .select => let cond := inputValues[0]! let a := inputValues[1]! @@ -519,8 +519,8 @@ def compileFunc (f : StableHLO.Parsing.Function) : Compile Unit := do let outputs ← f.funcType.range.mapM fun | .tensorType t => parseTensorType t | _ => throw "Function output must be a tensor type." - -- Since arguments are referred to by index, emit a statement for each - -- argument that assigns it to a named variable + /- Since arguments are referred to by index, emit a statement for each + argument that assigns it to a named variable -/ let preamble ← args.mapIdxM (fun i ⟨ name, v ⟩ => do match v with | .tensorType t => diff --git a/KLR/TGR/Dot.lean b/KLR/TGR/Dot.lean index cd0dd39f..83a0926e 100644 --- a/KLR/TGR/Dot.lean +++ b/KLR/TGR/Dot.lean @@ -9,7 +9,7 @@ import SHerLOC.Analysis.Graph open StableHLO.Analysis (Graph Edge Vertex) --- This module provides a way to convert an TGR function into a DOT graph representation. +/- This module provides a way to convert an TGR function into a DOT graph representation. -/ namespace KLR.TGR.Graph /- @@ -82,13 +82,13 @@ we create a separate vertex for each use of a constant. def graph (f : TGR.Function) : Graph := Id.run do let mut vertices := [] let mut edges := [] - -- Every variables in the function that is the result of a `constant` operatior + /- Every variables in the function that is the result of a `constant` operatior -/ let mut consts := f.statements.filterMap (fun | .assign v (.const ..) shape => .some ("const", v, shape) | .assign v (.full ..) shape => .some ("full", v, shape) | _ => .none) - -- A closure that creates edges from a list of inputs to an output variable. - -- If the input is a constant, it creates a vertex for that constant. + /- A closure that creates edges from a list of inputs to an output variable. + If the input is a constant, it creates a vertex for that constant. -/ let (makeEdges : List String → String → (List Vertex) × (List Edge)) := fun inputs output => Id.run do let mut vertices := [] let mut edges := [] @@ -101,7 +101,7 @@ def graph (f : TGR.Function) : Graph := Id.run do edges := (makeEdge (sanitize input) output) :: edges return (vertices, edges) - -- Walk the program statements and create vertices and edges. + /- Walk the program statements and create vertices and edges. -/ for s in f.statements do match s with | .assign _ (.const ..) _ | .assign _ (.full ..) _ => pure () diff --git a/KLR/TGR/Py.lean b/KLR/TGR/Py.lean index 9f4e744b..479256f0 100644 --- a/KLR/TGR/Py.lean +++ b/KLR/TGR/Py.lean @@ -89,7 +89,7 @@ def shapeToPy (s : Shape) : String := s.val.map toString |> ",".intercalate def varToPy (arg : Var) : String := - -- Prefix, since Python variables can't start with a digit + /- Prefix, since Python variables can't start with a digit -/ s!"var_{arg}" def opToPy (op : Operator) : String := @@ -148,7 +148,7 @@ def compileProgram (p : Program) : Format := p.functions.map compileFunction joinSep lines line --- Compile the TGR program to a Python program. +/- Compile the TGR program to a Python program. -/ def compile (p : Program) : String := (compileProgram p).pretty From 27e256bff89b86b484a5e60109e5fa4ae986f3b0 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 23 Jul 2025 10:49:41 -0400 Subject: [PATCH 14/15] Remove redundant args to const --- KLR/TGR/AST.lean | 4 ++-- KLR/TGR/Compile.lean | 2 +- KLR/TGR/Py.lean | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/KLR/TGR/AST.lean b/KLR/TGR/AST.lean index e590c895..27d15ae4 100644 --- a/KLR/TGR/AST.lean +++ b/KLR/TGR/AST.lean @@ -96,7 +96,7 @@ inductive Operator where -/ | broadcast (a : Var) (shape : Shape) /- create a constant tensor with the given values and shape -/ - | const (values : Tensor) (shape : Shape) (dtype : Dtype) + | const (values : Tensor) /- gather elements from a tensor using the provided indices and offset dimensions TODO: gather is complicated and not used except for in llama, so for now we just pass through the semantics of HLO's gather -/ @@ -234,7 +234,7 @@ instance : ToString Operator where | .transpose a dims => s!"transpose({a}, dims={dims})" | .reshape a shape => s!"reshape({a}, shape={shape})" | .broadcast a shape => s!"broadcast({a}, shape={shape})" - | .const t shape dtype => s!"const({repr t}, shape={shape}, dtype={dtype})" + | .const t => s!"const(..., shape={t.shape})" | .gather a indices offsetDims collapsedSliceDims startIndexMap indexVectorDim => s!" gather({a}, indices={indices}, offsetDims={offsetDims}, collapsedSliceDims={collapsedSliceDims}, startIndexMap={startIndexMap}, indexVectorDim={indexVectorDim})" | .slice a slices => s!"slice({a}, {slices})" diff --git a/KLR/TGR/Compile.lean b/KLR/TGR/Compile.lean index eae5785b..d22ac45a 100644 --- a/KLR/TGR/Compile.lean +++ b/KLR/TGR/Compile.lean @@ -369,7 +369,7 @@ def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := if t.shape.count != outputTy.shape.count then throw s!"Tensor literal shape {t.shape} does not match expected output shape {outputTy.shape}." let t ← t.reshape outputTy.shape - pure [.assign output (.const t outputTy.shape outputTy.dtype) outputTy] + pure [.assign output (.const t) outputTy] | _ => throw "Constant operation requires a 'value' attribute with tensor literal." /- tensor unary operators -/ | .reshape => do diff --git a/KLR/TGR/Py.lean b/KLR/TGR/Py.lean index 479256f0..d94aad5a 100644 --- a/KLR/TGR/Py.lean +++ b/KLR/TGR/Py.lean @@ -108,10 +108,9 @@ def opToPy (op : Operator) : String := | .transpose a dims => let dimsStr := dims.map toString |> ", ".intercalate s!"np.transpose({varToPy a}, axes=[{dimsStr}])" - | .split_with_sizes .. => panic! s!"Split with sizes operation not implemented in Python translation" | .reshape a shape => s!"{varToPy a}.reshape({shapeToPy shape})" | .broadcast a shape => s!"np.broadcast_to({varToPy a}, ({shapeToPy shape}))" - | .const _ shape _ => s!"np.random.random(({shapeToPy shape}))" -- TODO: make this use the actual constant value + | .const t => s!"np.random.random(({shapeToPy t.shape}))" -- TODO: make this use the actual constant value | .gather .. => panic! s!"Gather operation not implemented in Python translation" | .slice .. => panic! s!"Slice operation not implemented in Python translation" | .call .. => From 1c50034ff6d46800f5ce496c9381fe39c09a3249 Mon Sep 17 00:00:00 2001 From: Paul Biberstein Date: Wed, 23 Jul 2025 14:36:55 -0400 Subject: [PATCH 15/15] Fix manifest files --- KLR/Util/Gzip/lake-manifest.json | 50 ++++++++++++++++---------------- KLR/lake-manifest.json | 28 +++--------------- 2 files changed, 29 insertions(+), 49 deletions(-) diff --git a/KLR/Util/Gzip/lake-manifest.json b/KLR/Util/Gzip/lake-manifest.json index 8ed88edd..01308411 100644 --- a/KLR/Util/Gzip/lake-manifest.json +++ b/KLR/Util/Gzip/lake-manifest.json @@ -1,7 +1,17 @@ {"version": "1.1.0", "packagesDir": ".lake/packages", "packages": - [{"type": "path", + [{"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": false, + "configFile": "lakefile.lean"}, + {"type": "path", "scope": "", "name": "Util", "manifestFile": "lake-manifest.json", @@ -25,16 +35,6 @@ "inputRev": "v4.21.0", "inherited": false, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover/SHerLOC.git", - "type": "git", - "subDir": null, - "scope": "", - "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", - "name": "SHerLOC", - "manifestFile": "lake-manifest.json", - "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", - "inherited": true, - "configFile": "lakefile.lean"}, {"url": "https://github.com/leanprover/TensorLib.git", "type": "git", "subDir": null, @@ -49,41 +49,41 @@ "type": "git", "subDir": null, "scope": "", - "rev": "2ac43674e92a695e96caac19f4002b25434636da", + "rev": "c4aa78186d388e50a436e8362b947bae125a2933", "name": "plausible", "manifestFile": "lake-manifest.json", - "inputRev": "v4.20.0", + "inputRev": "v4.21.0", "inherited": true, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover-community/import-graph.git", + {"url": "https://github.com/leanprover-community/batteries", "type": "git", "subDir": null, - "scope": "", - "rev": "a11bcb5238149ae5d8a0aa5e2f8eddf8a3a9b27d", - "name": "importGraph", + "scope": "leanprover-community", + "rev": "8d2067bf518731a70a255d4a61b5c103922c772e", + "name": "batteries", "manifestFile": "lake-manifest.json", - "inputRev": "v4.20.0", + "inputRev": "v4.21.0", "inherited": true, "configFile": "lakefile.toml"}, {"url": "https://github.com/leanprover-community/aesop", "type": "git", "subDir": null, "scope": "", - "rev": "ddfca7829bf8aa4083cdf9633935dddbb28b7b2a", + "rev": "8ff27701d003456fd59f13a9212431239d902aef", "name": "aesop", "manifestFile": "lake-manifest.json", - "inputRev": "v4.20.0", + "inputRev": "v4.21.0", "inherited": true, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover-community/batteries", + {"url": "https://github.com/leanprover-community/import-graph.git", "type": "git", "subDir": null, - "scope": "leanprover-community", - "rev": "7a0d63fbf8fd350e891868a06d9927efa545ac1e", - "name": "batteries", + "scope": "", + "rev": "a11bcb5238149ae5d8a0aa5e2f8eddf8a3a9b27d", + "name": "importGraph", "manifestFile": "lake-manifest.json", "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}], "name": "Gzip", - "lakeDir": ".lake"} + "lakeDir": ".lake"} \ No newline at end of file diff --git a/KLR/lake-manifest.json b/KLR/lake-manifest.json index e6ad0719..49c699d0 100644 --- a/KLR/lake-manifest.json +++ b/KLR/lake-manifest.json @@ -51,16 +51,6 @@ "inputRev": "v4.21.0", "inherited": false, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover/lean4-cli.git", - "type": "git", - "subDir": null, - "scope": "", - "rev": "02dbd02bc00ec4916e99b04b2245b30200e200d0", - "name": "Cli", - "manifestFile": "lake-manifest.json", - "inputRev": "v4.19.0", - "inherited": true, - "configFile": "lakefile.toml"}, {"url": "https://github.com/leanprover-community/import-graph.git", "type": "git", "subDir": null, @@ -71,25 +61,15 @@ "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover-community/aesop", + {"url": "https://github.com/leanprover/lean4-cli.git", "type": "git", "subDir": null, "scope": "", - "rev": "ddfca7829bf8aa4083cdf9633935dddbb28b7b2a", - "name": "aesop", - "manifestFile": "lake-manifest.json", - "inputRev": "v4.20.0", - "inherited": true, - "configFile": "lakefile.toml"}, - {"url": "https://github.com/leanprover-community/batteries", - "type": "git", - "subDir": null, - "scope": "leanprover-community", - "rev": "7a0d63fbf8fd350e891868a06d9927efa545ac1e", - "name": "batteries", + "rev": "f9e25dcbed001489c53bceeb1f1d50bbaf7451d4", + "name": "Cli", "manifestFile": "lake-manifest.json", "inputRev": "v4.20.0", "inherited": true, "configFile": "lakefile.toml"}], "name": "Util", - "lakeDir": ".lake"} + "lakeDir": ".lake"} \ No newline at end of file