diff --git a/.gitignore b/.gitignore index 022ef304..b8107b32 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ build/ dist/ klr.bin /.wheel/ +env diff --git a/KLR/KLR/Compile.lean b/KLR/KLR/Compile.lean new file mode 100644 index 00000000..e69de29b diff --git a/KLR/TGR.lean b/KLR/TGR.lean new file mode 100644 index 00000000..0a76b340 --- /dev/null +++ b/KLR/TGR.lean @@ -0,0 +1 @@ +import KLR.TGR.Basic diff --git a/KLR/TGR/AST.lean b/KLR/TGR/AST.lean new file mode 100644 index 00000000..27d15ae4 --- /dev/null +++ b/KLR/TGR/AST.lean @@ -0,0 +1,282 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.Core.Operators +import KLR.Util +import SHerLOC +import TensorLib.Dtype +import TensorLib.Shape +import TensorLib.Slice +import TensorLib.Tensor + +open TensorLib (Tensor Shape Dtype Slice) + +/- +The definition of the Tensor Graph Representation (TGR) IR. The goal of this IR is to +be a uniform representation for graphs of tensor operations, which we can use as a +common compilation target for different frontends (e.g. StableHLO, PyTorch FX, etc.). +A TGR program consists of a list of functions, each with a name, and input and output tensors. +The function body is in SSA, with each operation producing a single output tensor. +-/ +namespace KLR.TGR + +structure TensorTy where + shape : Shape + dtype : Dtype +deriving Inhabited, Repr, Nonempty + +abbrev Var := String + +/- scalar-scalar binary operators -/ +inductive BinaryOp where + | add + | sub + | mul + | div + | and + | max + | cmp +deriving Inhabited, Repr + +/- scalar unary operators -/ +inductive UnaryOp where + | exp + | sqrt + | neg + | convert (dtype : Dtype) +deriving Inhabited, Repr + +/- +Operators in the TGR (Tensor Graph Representation) IR. 
+ +Note: some HLO operations have "load-bearing" output shapes, meaning the +output shape is a vital part of the operation's semantics (e.g. `reshape`). +For these operators, we store the output shape in the `Operator`, even +though this means that when considering an `Operator` as part of a `Statement`, +the output shape information exists in two redundant places: in the `Statement` +and in the `Operator`. +-/ +inductive Operator where + /- An argument to the function, identified by its index. -/ + | arg (index : Nat) + + /- apply a binary operator element-wise to two tensors -/ + | binaryOp (op : BinaryOp) (a b : Var) + /- apply a unary operator element-wise to a tensor -/ + | unaryOp (op : UnaryOp) (a : Var) + /- apply a reduction operation to a tensor, reducing it along the specified dimensions -/ + | reductionOp (op : BinaryOp) (a b : Var) (dim : List Nat) + + /- perform a batch matrix multiplication on two tensors. + Specifically, computes the einsum bij,bkj->bik -/ + | batchMatmul (a b : Var) + /- create a tensor with a range of values within the given limits and with the specified stride -/ + | arange (start : Nat) (stop : Nat) (step : Nat) (shape : Shape) + /- concatenate a list of tensors along the specified dimension -/ + | concat (tensors : List Var) (dim : Nat) + /- select elements from two tensors based on a condition tensor -/ + | select (cond a b : Var) + /- create a tensor filled with a specific value, with the given shape + Note: the tensor is always a scalar-array -/ + | full (value : Tensor) (shape : Shape) + /- transpose a tensor with the provided permutation of dimensions -/ + | transpose (a : Var) (dims : List Nat) + /- reshape a tensor to the specified shape -/ + | reshape (a : Var) (shape : Shape) + /- + broadcast a tensor to the specified shape + + It must be the case that `len(shape(a)) = len(shape)` and that + `∀ i, shape(a)[i] != shape[i] => shape(a)[i] == 1` + In other words, the broadcast cannot produce new dimensions, only expand + 
existing ones of size 1. + -/ + | broadcast (a : Var) (shape : Shape) + /- create a constant tensor with the given values and shape -/ + | const (values : Tensor) + /- gather elements from a tensor using the provided indices and offset dimensions + TODO: gather is complicated and not used except for in llama, so for now + we just pass through the semantics of HLO's gather -/ + | gather (input indices : Var) (offsetDims collapsedSliceDims startIndexMap : List Nat) (indexVectorDim : Nat) + /- slice a tensor along specified dimensions, with start, limit, and stride -/ + | slice (a : Var) (slice : List Slice) + /- call another function, passing input values and receiving outputs -/ + | call (callee : String) (inputValues : List Var) +deriving Inhabited, Repr + +/- +A statement in TGR (Tensor Graph Representation). +In SSA form, so each variable is assigned exactly once. +-/ +inductive Statement where + /- A comment in the code, for making the dumped IR readable -/ + | comment (msg : String) + /- + Assign the result of `op` to `dest` , with resulting shape `shape` + + Note: We store the shape directly, even though it is inferrable based on the, + operator, to avoid having to recompute it with fallible operations later. + -/ + | assign (dest : Var) (op : Operator) (shape : TensorTy) + /- Return variables `vars` from the function -/ + | ret (vars : List Var) +deriving Inhabited, Repr + +/- +An TGR function. Note that arguments are referred to by index, so +we only store the argument shapes, not names. +-/ +structure Function where + name : String + inputs : List TensorTy + outputs : List TensorTy + statements : List Statement +deriving Inhabited, Repr, Nonempty + +/- A TGR program -/ +structure Program where + functions : List Function +deriving Inhabited, Repr, Nonempty + +/- Returns the list of variables that this operator immediately depends on. 
-/
def dependencies : Operator → List Var
  | .arg _ => []
  | .binaryOp _ a b => [a, b]
  | .unaryOp _ a => [a]
  | .reductionOp _ a b _ => [a, b]
  | .batchMatmul a b => [a, b]
  | .arange .. => []
  | .concat tensors _ => tensors
  | .select cond a b => [cond, a, b]
  | .full .. => []
  | .transpose a _ => [a]
  | .reshape a _ => [a]
  | .broadcast a .. => [a]
  | .const .. => []
  | .gather a i .. => [a, i]
  | .slice a .. => [a]
  | .call _ inputs => inputs

/- Returns the list of all variables defined in this function. -/
def vars (f : Function) : List Var :=
  f.statements.filterMap (fun
    | .assign dest .. => .some dest
    | _ => .none)

/- Finds the operator that assigns to a variable in the function. -/
def findVar (f : Function) (v : Var) : Option Operator :=
  f.statements.findSome? (fun
    | .assign dest op _ => if dest == v then .some op else .none
    | _ => .none)

/- TODO: move these toString instances to the TensorLib repo -/
instance : ToString Slice where
  toString s :=
    let {start, stop, step, ..} := s
    -- A missing bound or step renders as the empty string, mirroring Python slice syntax.
    let start := start.map toString |>.getD ""
    let stop := stop.map toString |>.getD ""
    let step := step.map toString |>.getD ""
    s!"{start}:{stop}:{step}"

instance : ToString Shape where
  toString s :=
    s.val.map toString |> "x".intercalate |> fun x => s!"[{x}]"

instance : ToString Dtype where
  toString
  | .bool => "bool"
  | .int8 => "i8"
  | .int16 => "i16"
  | .int32 => "i32"
  | .int64 => "i64"
  | .uint8 => "u8"
  | .uint16 => "u16"
  | .uint32 => "u32"
  | .uint64 => "u64"
  | .float32 => "f32"
  | .float64 => "f64"

instance : ToString BinaryOp where
  toString
  | .add => "add"
  | .sub => "sub"
  | .mul => "mul"
  | .div => "div"
  | .and => "and"
  | .max => "max"
  | .cmp => "cmp"

instance : ToString UnaryOp where
  toString
  | .exp => "exp"
  | .sqrt => "sqrt"
  | .neg => "neg"
  | .convert dtype => s!"convert_{dtype}"

/- A tensor type prints as its shape immediately followed by its dtype, e.g. `[2x3]f32`. -/
instance : ToString TensorTy where
  toString (t : TensorTy) : String :=
    s!"{t.shape}{t.dtype}"

instance : ToString Operator where
  toString
  | .arg n => s!"arg({n})"
  | .binaryOp binOp a b => s!"{binOp}({a}, {b})"
  | .unaryOp unOp a => s!"{unOp}({a})"
  | .reductionOp redOp a b dim => s!"reduce-{redOp}({a}, {b}, dim={dim})"
  | .batchMatmul a b => s!"matmul({a}, {b})"
  | .arange start stop step shape => s!"arange({start}, {stop}, {step}, shape={shape})"
  | .concat tensors dim => s!"concat({", ".intercalate tensors}, dim={dim})"
  | .select cond a b => s!"select({cond}, {a}, {b})"
  | .full v shape => s!"full({repr v}, shape={shape})"
  | .transpose a dims => s!"transpose({a}, dims={dims})"
  | .reshape a shape => s!"reshape({a}, shape={shape})"
  | .broadcast a shape => s!"broadcast({a}, shape={shape})"
  | .const t => s!"const(..., shape={t.shape})"
  | .gather a indices offsetDims collapsedSliceDims startIndexMap indexVectorDim
    -- Fixed: no leading space, consistent with how every other operator is printed.
    => s!"gather({a}, indices={indices}, offsetDims={offsetDims}, collapsedSliceDims={collapsedSliceDims}, startIndexMap={startIndexMap}, indexVectorDim={indexVectorDim})"
  | .slice a slices => s!"slice({a}, {slices})"
  | .call callee inputValues =>
    let inputsStr := inputValues.map toString |> ", ".intercalate
    s!"call({callee}, inputs=[{inputsStr}])"

instance : ToString Statement where
  toString
  | .comment msg => s!"# {msg}"
  | .assign dest op shape => s!"{dest} : {toString shape} = {op}"
  | .ret name => s!"return {name}"

instance : ToString Function where
  toString f :=
    let inputsStr := f.inputs.map toString |> ", ".intercalate
    let outputsStr := f.outputs.map toString |> ", ".intercalate
    let statementsStr := f.statements.map toString |> "\n".intercalate
    s!"def {f.name}({inputsStr}) -> ({outputsStr}):\n{statementsStr}"

instance : ToString Program where
  toString p :=
    let functionsStr := p.functions.map toString |> "\n".intercalate
    s!"# Program\n" ++ functionsStr

/- Human readable name for the operator.
-/ +def opName : Operator → String + | .arg _ => s!"arg" + | .binaryOp binOp .. => s!"{binOp}" + | .unaryOp unOp .. => s!"{unOp}" + | .reductionOp redOp .. => s!"{redOp}" + | .batchMatmul .. => s!"batchMatmul" + | .arange .. => s!"arange" + | .concat .. => s!"concat" + | .select .. => s!"select" + | .full .. => s!"full" + | .transpose .. => s!"transpose" + | .reshape .. => s!"reshape" + | .broadcast .. => s!"broadcast" + | .const .. => s!"const" + | .gather .. => s!"gather" + | .slice .. => s!"slice" + | .call callee .. => s!"call {callee}" + +end KLR.TGR diff --git a/KLR/TGR/Basic.lean b/KLR/TGR/Basic.lean new file mode 100644 index 00000000..d3985ea8 --- /dev/null +++ b/KLR/TGR/Basic.lean @@ -0,0 +1,4 @@ +import KLR.TGR.AST +import KLR.TGR.Compile +import KLR.TGR.Dot +import KLR.TGR.Py diff --git a/KLR/TGR/Compile.lean b/KLR/TGR/Compile.lean new file mode 100644 index 00000000..d22ac45a --- /dev/null +++ b/KLR/TGR/Compile.lean @@ -0,0 +1,544 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.Core.Operators +import KLR.TGR.AST +import KLR.Util +import SHerLOC +import TensorLib.Shape +import TensorLib.Slice +import TensorLib.Tensor +import TensorLib.Bytes +import TensorLib.ByteArray + +open TensorLib (Dtype Shape Tensor) + +/- This module compiles a StableHLO program into an TGR program. -/ +namespace KLR.TGR.Compile + +abbrev SymbolEnv := List (String × Nat) + +/- Context for the compilation process, to be stored in a state monad. -/ +structure Ctx where + /- the program being compiled -/ + program : Program + /- the log of messages generated during compilation (for debugging) -/ + log : List String + gensymEnv : SymbolEnv +deriving Inhabited, Repr + +/- Compilation requires tracking state and also potentially returning errors. -/ +abbrev Compile T := StM Ctx T + +/- Emit a message to the compilation log. 
-/
def log (msg : String) : Compile Unit :=
  modify fun ctx => { ctx with log := ctx.log ++ [msg] }

/- Append a function to the program under construction. -/
def addFunction (func : Function) : Compile Unit :=
  modify fun ctx =>
    { ctx with program := { ctx.program with functions := ctx.program.functions ++ [func] } }

/-
Generate a fresh variable name derived from `name`.

TODO: This does not actually guarantee that the name is unique, since it
only checks the gensymEnv. We also need to check against all existing
variables in the program.
-/
def gensym (name : String) : Compile Var := do
  let env := (← get).gensymEnv
  -- Next suffix for this base name: one past the most recent use, or 0 if unseen.
  let fresh :=
    match env.find? fun ⟨ n, _ ⟩ => n == name with
    | some (_, last) => last + 1
    | none => 0
  modify fun ctx => { ctx with gensymEnv := (name, fresh) :: ctx.gensymEnv }
  pure s!"{name}_{fresh}"

/- Permute `l` according to the indices in `permutation`; `none` if any index is out of range. -/
def permute {T : Type} (l : List T) (permutation : List Nat) : Option (List T) :=
  permutation.mapM fun i => l[i]?

/-
Parses a StableHLO float literal to a Float32.

TODO: Probably has all sorts of rounding errors, but dealing with floats in Lean
is so agonizing, and the semantics of this number storage system are so confusing,
that I'm satisfied with this as a first pass for now.

Examples:
The FloatLiteral
  intPart: {.minus, 4}
  fracPart: {.plus, 785980}
  sciPart: {.minus, 1}
represents the number -4.4785980e-1, which is
  -(4 + 0.4785980) * 10 ^ -1

Alternatively, if we have
  intPart: {.minus, 3}
  fracPart: {.plus, 597620}
  sciPart: {.minus, 3}
then it represents the number -3.597620e-3, which is
  -(3 + 0.597620) * 10 ^ -3

Note: the integer part's sign applies to the WHOLE mantissa (integer and
fractional digits together), as the #guard examples below demonstrate.
-/
def parseFloat (c : StableHLO.Parsing.FloatLiteral) : Float32 :=
  match c with
  | .decimal {
      integerPart := ⟨ intSign, intDecimal ⟩,
      fractionalPart := ⟨ _, fracDecimal ⟩,
      scientificPart := ⟨ sciSign, sciDecimal ⟩
    } =>
    -- Exponent contributed by the scientific part, e.g. `e-3` gives -3.
    let exponent := match sciSign with
    | .plus => Int.ofNat sciDecimal
    | .minus => -1 * Int.ofNat sciDecimal
    let sign := match intSign with
    | .plus => 1
    | .minus => (-1 : Float32)
    -- Fold the fractional digits into a single integer mantissa.
    -- NOTE(review): leading zeros in the fractional part would be dropped by this
    -- concatenation (e.g. frac 0123 becomes "123") — TODO confirm the parser never
    -- produces fractional parts with leading zeros.
    let mantissa := String.toNat! s!"{intDecimal}{fracDecimal}"
    let fracDigits := ToString.toString fracDecimal |>.length |> Int.ofNat
    -- Shift the exponent to compensate for the fractional digits now in the mantissa.
    let exponent := exponent - fracDigits
    let (exponentSign, exponent) := if exponent >= 0 then
      (false, exponent.natAbs)
    else
      (true, exponent.natAbs)
    sign * OfScientific.ofScientific mantissa exponentSign exponent
  | .hexaDecimal n => UInt32.ofNat n |> Float32.ofBits

#guard parseFloat (.decimal { integerPart := ⟨ .plus, 4 ⟩, fractionalPart := ⟨ .plus, 785980 ⟩, scientificPart := ⟨ .plus, 3 ⟩}) == 4785.980
#guard parseFloat (.decimal { integerPart := ⟨ .minus, 4 ⟩, fractionalPart := ⟨ .plus, 785980 ⟩, scientificPart := ⟨ .minus, 1 ⟩}) == -0.4785980
#guard parseFloat (.decimal { integerPart := ⟨ .minus, 3 ⟩, fractionalPart := ⟨ .plus, 597620 ⟩, scientificPart := ⟨ .minus, 3 ⟩}) == -0.003597620

/-
Parses a hex string (starting with "0x" and followed by a multiple of 8 hex characters)
into a rank-1 tensor of int32 values.

Assumes bytes are in big-endian order.
-/
def parseInt32TensorFromHex (s : String) : Compile Tensor := do
  /- Tail recursive helper to convert a list of hex characters to a list of BitVec 32 values. -/
  let rec toBitVec32List (str : List Char) (acc : List (BitVec 32)) : Compile (List (BitVec 32)) := do
    match str with
    | c0 :: c1 :: c2 :: c3 :: c4 :: c5 :: c6 :: c7 :: rest =>
      match KLR.Util.Hex.hexCharsToBitVecBE c0 c1 c2 c3 c4 c5 c6 c7 with
      | some v => toBitVec32List rest (v :: acc)
      -- Report all eight characters of the offending group (previously only the
      -- first four were shown, hiding the actual bad character half the time).
      | none => throw s!"Invalid hex character sequence: {c0}{c1}{c2}{c3}{c4}{c5}{c6}{c7}."
    | [] => pure acc.reverse
    | _ => throw s!"Hex string must have a number of characters divisible by 8, but got {str}."

  /- Trim off the leading "0x" and convert the rest to a list of BitVec 32 values. -/
  let bvs ← match s.toList with
  | '0' :: 'x' :: rest => toBitVec32List rest []
  | _ => throw s!"Hex string must start with '0x', but got {s}."
  /- Concatenate the bitvecs to create a little-endian bytearray -/
  let data := bvs.foldr (fun new acc => (TensorLib.toLEByteArray new) ++ acc) ByteArray.empty
  pure {
    dtype := TensorLib.Dtype.int32,
    shape := Shape.mk [bvs.length],
    data
  }

/- Convert a list of tensors to a single tensor by concatenating them along a new first dimension. -/
def tensorOfTensorList (ns : List Tensor) : Compile Tensor :=
  match ns with
  | f :: r => do
    if ! (r.all fun t => t.shape == f.shape) then
      throw s!"All tensors in the list must have the same shape, but got {repr ns}."
    else if ! (r.all fun t => t.dtype == f.dtype) then
      throw s!"All tensors in the list must have the same dtype, but got {repr ns}."
    else
      let newShape := Shape.mk (ns.length :: f.shape.val)
      let dtype := f.dtype
      let arr := Tensor.zeros dtype newShape
      let mut data := arr.data
      let mut posn := 0
      for t in ns do
        -- Copy exactly `t.data.size` bytes per tensor and advance by that amount.
        -- (The previous version scaled both the copy length and the offset by the
        -- total element count across all tensors, so every tensor after the first
        -- was written at a wildly wrong offset.)
        data := t.data.copySlice 0 data posn t.data.size
        posn := posn + t.data.size
      .ok { arr with data := data }
  | [] => pure (TensorLib.Tensor.empty TensorLib.Dtype.float32 (Shape.mk []))

/-
Parse a StableHLO string literal to a Tensor

Note that the meaning of string literals in StableHLO is not well-defined,
but in practice JAX uses them to hex encode integer tensors.
-/
def parseStringLiteral := parseInt32TensorFromHex
/- Parse a StableHLO boolean literal to a Bool. -/
def parseBoolLiteral : StableHLO.Parsing.BooleanLiteral → Bool
  | .true => true
  | .false => false
/- Parse a StableHLO element literal to a TGR tensor (always a scalar-array). -/
def parseElementLiteral : StableHLO.Parsing.ElementLiteral → Compile TensorLib.Tensor
  | .floatLiteral f => f |> parseFloat |> TensorLib.Tensor.arrayScalarFloat32! |> pure
  | .stringLiteral s => s |> parseStringLiteral
  | .booleanLiteral b => b |> parseBoolLiteral |> TensorLib.Tensor.arrayScalarBool! |> pure
  | .complexLiteral _ => impossible "unimplemented"

/- Parse a StableHLO dense literal to a TGR tensor. -/
def parseTensorLiteral : StableHLO.Parsing.DenseLiteral → Compile Tensor
  /- special case for singleton tensors so we don't create an extra dimension -/
  | .denseElements [v] => do
    parseElementLiteral v
  | .denseElements l => do
    tensorOfTensorList (← l.mapM parseElementLiteral)
  | .denseDimension l => do
    tensorOfTensorList (← l.mapM parseTensorLiteral)

/- Convert a StableHLO tensor type to a TGR TensorTy.
-/
def parseTensorType (t : StableHLO.Parsing.TensorType) : Compile TensorTy := do
  -- Reject dynamic dimensions: TGR shapes are fully static.
  let dims ← t.shape.mapM fun
    | .known d => pure d
    | .unknown => throw "Can't support tensors with dynamic shape"
  -- Map the StableHLO element type onto the corresponding TensorLib dtype.
  let elementDtype ← match t.tensorElementTypeGen with
  | .classic (.floatType .f32) => pure TensorLib.Dtype.float32
  | .classic (.floatType .f64) => pure TensorLib.Dtype.float64
  | .classic (.integerType {sign := .signed, size := .b8}) => pure TensorLib.Dtype.int8
  | .classic (.integerType {sign := .unsigned, size := .b8}) => pure TensorLib.Dtype.uint8
  | .classic (.integerType {sign := .signed, size := .b16}) => pure TensorLib.Dtype.int16
  | .classic (.integerType {sign := .unsigned, size := .b16}) => pure TensorLib.Dtype.uint16
  | .classic (.integerType {sign := .signed, size := .b32}) => pure TensorLib.Dtype.int32
  | .classic (.integerType {sign := .unsigned, size := .b32}) => pure TensorLib.Dtype.uint32
  | .classic (.integerType {sign := .signed, size := .b64}) => pure TensorLib.Dtype.int64
  | .classic (.integerType {sign := .unsigned, size := .b64}) => pure TensorLib.Dtype.uint64
  | .classic (.booleanType) => pure TensorLib.Dtype.bool
  | _ => throw s!"Unsupported tensor element type: {repr t.tensorElementTypeGen}"
  pure (TensorTy.mk (Shape.mk dims) elementDtype)

/- Parse the TGR TensorTy at index `n` from the list of types. -/
def parseTensorTypeFromValueTypes (l : List StableHLO.Parsing.ValueType) (n : Nat) : Compile TensorTy :=
  match l[n]? with
  | some (.tensorType t) => parseTensorType t
  | some t => throw s!"Element {n} of type list must have tensor type, but got {repr t}."
  | none => throw s!"Type list must have at least {n + 1} values, but got only {l.length}."

/- Parse a TGR TensorTy from a list of types, expecting the list to have exactly one element.
-/
def parseSingleTensorTypeFromValueTypes : List StableHLO.Parsing.ValueType → Compile TensorTy
  | [.tensorType t] => parseTensorType t
  | t => throw s!"Expected type list to have a single tensor type, but got {repr t}."

/- Parse an array from a StableHLO literal. -/
def parseArray (c : StableHLO.Parsing.Literal) : Compile (List Nat) :=
  match c with
  | .array (.array64 entries) => pure (entries.map fun ⟨ _sign, n ⟩ => n)
  | .array (.array1 _) => throw "array1 unimplemented."
  | _ => throw "Expected an array of integers."

/-
Parse a Nat from a StableHLO float literal.
We need this because integers are often represented as floats in StableHLO.
-/
def parseNatFromElementLiteral (c : StableHLO.Parsing.Literal) : Compile Nat :=
  match c with
  | .element (.floatLiteral (.decimal {integerPart, fractionalPart, scientificPart})) =>
    -- Only a bare non-negative integer is acceptable: no fractional digits,
    -- no exponent, and a positive sign.
    if fractionalPart.decimal != 0 || scientificPart.decimal != 0 then
      throw s!"Expected a non-negative integer, but got a float literal with fractional or scientific part: {repr c}."
    else
      match integerPart.sign with
      | .plus => pure integerPart.decimal
      | .minus => throw s!"Expected a non-negative integer, but got a float literal with negative sign: {repr c}."
  | .element (.floatLiteral l) => throw s!"Got unsupported float literal {repr l}."
  | l => throw s!"Expected a float literal but got {repr l}."

/- Find an attribute by name in a list of attributes -/
def lookupAttribute (attrs : List StableHLO.Parsing.Attribute) (name : String) : Compile StableHLO.Parsing.Constant := do
  let some ⟨ _, attr ⟩ := attrs.find? fun ⟨ id, _ ⟩ => id == name
    | throw s!"Attribute '{name}' not found."
  pure attr

/- Find an attribute by name in a list of attributes, returning only the associated literal, not its type -/
def lookupAttributeValue (attrs : List StableHLO.Parsing.Attribute) (name : String) : Compile StableHLO.Parsing.Literal := do
  let ⟨ lit, _ ⟩ ← lookupAttribute attrs name
  pure lit

/- Get the value of a field in a StableHLO record, expecting it to be a list of integers. -/
def lookupNatsInFields (fields : List StableHLO.Parsing.StableHLORecordField) (name : String) : Compile (List Nat) :=
  match fields.find? fun ⟨ n, _ ⟩ => n == name with
  -- An absent field is deliberately treated as the empty list (callers rely on this).
  | none => pure []
  | some (.mk _ (.many ns)) => pure ns
  | some v => throw s!"Field '{name}' must be a list of integers, but got {repr v}."

/- Get the value of a field in a StableHLO record, expecting it to be a single integer. -/
def lookupNatInFields (fields : List StableHLO.Parsing.StableHLORecordField) (name : String) : Compile Nat :=
  match fields.find? fun ⟨ n, _ ⟩ => n == name with
  | some (.mk _ (.one n)) => pure n
  | some v => throw s!"Field '{name}' must be a single integer, but got {repr v}."
  | none => throw s!"Field '{name}' not found in record list {repr fields}."

/- Extract the arguments to the `dotGeneral` operation from a record in the list of attributes. -/
def extractDotDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile (List Nat × List Nat × List Nat × List Nat) := do
  match ← lookupAttributeValue attrs "dot_dimension_numbers" with
  | .stableHLORecord fields =>
    let lhsBatching ← lookupNatsInFields fields "lhs_batching_dimensions"
    let lhsContracting ← lookupNatsInFields fields "lhs_contracting_dimensions"
    let rhsBatching ← lookupNatsInFields fields "rhs_batching_dimensions"
    let rhsContracting ← lookupNatsInFields fields "rhs_contracting_dimensions"
    pure (lhsBatching, lhsContracting, rhsBatching, rhsContracting)
  | _ => throw "Attribute 'dot_dimension_numbers' must be a stableHLORecord."

/- Extract the arguments to the `gather` operation from a record in the list of attributes. -/
def extractDimensionNumbers (attrs : List StableHLO.Parsing.Attribute) : Compile (List Nat × List Nat × List Nat × Nat) := do
  match ← lookupAttributeValue attrs "dimension_numbers" with
  | .stableHLORecord fields =>
    let offsetDims ← lookupNatsInFields fields "offset_dims"
    let collapsedSliceDims ← lookupNatsInFields fields "collapsed_slice_dims"
    let startIndexMap ← lookupNatsInFields fields "start_index_map"
    let indexVectorDim ← lookupNatInFields fields "index_vector_dim"
    pure (offsetDims, collapsedSliceDims, startIndexMap, indexVectorDim)
  | _ => throw "Attribute 'dimension_numbers' must be a stableHLORecord."

/-
The StableHLO `reduce` operation always calls an arbitrary reduction function.
However, in TGR we only support a few specific reduction operations (mostly
arithmetic and logical binary operators).
Since many StableHLO programs only +use these basic reduction operations, we can recognize when the StableHLO function +called by a `reduce` operation is one of these basic operations, and convert it +to the corresponding TGR BinaryOp. +If this process is unsuccessful, it means that the input `reduce` function is more +complicated and can't be supported by the current TGR design. +-/ +def reduceFunctionToReduceOp (f : StableHLO.Parsing.InputFunc) : Compile (BinaryOp) := do + match f with + | .mk _ [.stablehlo .maximum .., .return ..] => pure .max + | .mk _ [.stablehlo .add .., .return ..] => pure .add + | .mk _ [.stablehlo .and .., .return ..] => pure .and + | .mk _ [.stablehlo op .., .return ..] => throw s!"Unimplemented reduction function {repr op}." + | op => + throw ("Unable to recognize `reduce` function as simple binary operator. Compiling" ++ + "this program likely requires adding support for arbitrary function calling in `reduce`" + ++ s!"Function: {repr op}") + +/- +Compile a StableHLO operation into a list of TGR statements. + +Note: this function annotates each statement with the type of its output, +but this type is merely passed through from the HLO program, not computed anew. +This means it's possible that if there's a mistake in the shape calculation +in the HLO program, the TGR statements will also have incorrect shapes. +Eventually, we'll want a function that can shape-check an TGR program. +-/ +def compileOp (op : StableHLO.Parsing.Operation) : Compile (List Statement) := do + log s!"Compiling operation {repr op}" + match op with + | .stablehlo opCode inputValues inputFunctions inputAttributes outputs signature => do + /- Reuse the variable names and shapes from the StableHLO program -/ + let output ← match outputs with + | [output] => pure output + | _ => throw "Operator signature must have a single output." 
+ let outputTy ← parseSingleTensorTypeFromValueTypes signature.range + /- helper function to emit TGR for element-wise unary ops -/ + let makeUnOp := fun (op : UnaryOp) => do + let a := inputValues[0]! + pure [.assign output (.unaryOp op a) outputTy] + /- helper function to emit TGR for element-wise binary ops -/ + let makeBinOp := fun (op : BinaryOp) => do + let a := inputValues[0]! + let b := inputValues[1]! + pure [.assign output (.binaryOp op a b) outputTy] + match opCode with + /- element-wise unary operators -/ + | .sqrt => makeUnOp .sqrt + | .negate => makeUnOp .neg + | .exponential => makeUnOp .exp + | .convert => makeUnOp (UnaryOp.convert outputTy.dtype) + /- element-wise binary operators -/ + | .compare => makeBinOp .cmp + | .multiply => makeBinOp .mul + | .add => makeBinOp .add + | .and => makeBinOp .and + | .maximum => makeBinOp .max + | .subtract => makeBinOp .sub + | .divide => makeBinOp .div + /- tensor nullary operators -/ + | .constant => do + let valueAttr ← lookupAttributeValue inputAttributes "value" + match valueAttr with + | (.tensor t) => do + let t ← parseTensorLiteral t + if t.shape == Shape.empty then + /- If the tensor is a scalar-array, we use a `full` operation + to create a tensor of the same shape as the output. -/ + pure [.assign output (.full t outputTy.shape) outputTy] + else + /- If the tensor is not a scalar-array, it corresponds to a + `const` operation. -/ + if t.shape.count != outputTy.shape.count then + throw s!"Tensor literal shape {t.shape} does not match expected output shape {outputTy.shape}." + let t ← t.reshape outputTy.shape + pure [.assign output (.const t) outputTy] + | _ => throw "Constant operation requires a 'value' attribute with tensor literal." + /- tensor unary operators -/ + | .reshape => do + let input := inputValues[0]! 
+ pure [.assign output (.reshape input outputTy.shape) outputTy] + | .gather => + let (offsetDims, collapsedSliceDims, startIndexMap, indexVectorDim) ← + extractDimensionNumbers inputAttributes + let input := inputValues[0]! + let indices := inputValues[1]! + pure [.assign output (.gather input indices offsetDims collapsedSliceDims startIndexMap indexVectorDim) outputTy] + | .slice => + let input := inputValues[0]! + let start ← (← lookupAttributeValue inputAttributes "start_indices") |> parseArray + let limit ← (← lookupAttributeValue inputAttributes "limit_indices") |> parseArray + let stride ← (← lookupAttributeValue inputAttributes "strides") |> parseArray + let slices ← start + |> List.length + |> List.range + |> List.mapM (fun i => + TensorLib.Slice.make + (.some $ Int.ofNat start[i]!) + (.some $ Int.ofNat limit[i]!) + (.some $ Int.ofNat stride[i]!)) + pure [.assign output (.slice input slices) outputTy] + | .reduce => + let op ← reduceFunctionToReduceOp inputFunctions[0]! + let dims ← (← lookupAttributeValue inputAttributes "dimensions") |> parseArray + pure [.assign output (.reductionOp op inputValues[0]! inputValues[1]! dims) outputTy] -- TODO: init value + | .broadcastInDim => do + /- + We compile a broadcast by first reshaping to a new tensor `t` with added + dimensions of size 1 such that the rank is equal to the output rank, then + we broadcast `t` to the output shape which expands some of those + dimensions of size 1 + -/ + let input := inputValues[0]! + let inputTy := ← parseTensorTypeFromValueTypes signature.domain 0 + let broadcastDims ← (← lookupAttributeValue inputAttributes "broadcast_dimensions") |> parseArray + let reshaped ← gensym (input ++ "_reshaped") + /- A shape that has the same number of dimensions as the output, but where + specified dimensions match the input shape, and others are 1. -/ + let newShape := outputTy.shape.ndim |> List.range |> List.map (fun n => + if let .some i := broadcastDims.idxOf? n then + inputTy.shape.val[i]! 
+ else + 1) + |> .mk + pure [ + .assign reshaped (.reshape input newShape) ⟨ newShape, inputTy.dtype ⟩, + .assign output (.broadcast reshaped outputTy.shape) outputTy + ] + | .transpose => do + let input := inputValues[0]! + let dims ← (← lookupAttributeValue inputAttributes "permutation") |> parseArray + pure [.assign output (.transpose input dims) outputTy] + /- tensor binary operators -/ + | .dotGeneral => do + /- + The semantics of the `dotGeneral` operation are complex, see the + specification for details. The variables here are named similar to the + variables in the spec to aid comprehension. + https://github.com/openxla/stablehlo/blob/6f7b4ab8f96dc65cf3c8e9824836117d2934cc45/docs/spec.md?#dot_general + -/ + /- Gather metadata from the inputs -/ + let (lhsBatchingDims, lhsContractingDims, rhsBatchingDims, rhsContractingDims) ← + extractDotDimensionNumbers inputAttributes + let lhs := inputValues[0]! + let rhs := inputValues[1]! + let lhsType ← parseTensorTypeFromValueTypes signature.domain 0 + let rhsType ← parseTensorTypeFromValueTypes signature.domain 1 + let dtype := lhsType.dtype + let lhsShape := lhsType.shape + let rhsShape := rhsType.shape + let lhsDims := List.range (TensorLib.Shape.ndim lhsShape) + let rhsDims := List.range (TensorLib.Shape.ndim rhsShape) + /- Calculate shapes of intermediate tensors and output -/ + let lhsResultDims := lhsDims.filter (fun i => !lhsBatchingDims.contains i && !lhsContractingDims.contains i) + let rhsResultDims := rhsDims.filter (fun i => !rhsBatchingDims.contains i && !rhsContractingDims.contains i) + let lhsTransposePerm := lhsBatchingDims ++ lhsResultDims ++ lhsContractingDims + let rhsTransposePerm := rhsBatchingDims ++ rhsResultDims ++ rhsContractingDims + let lhsTransposedShape := permute lhsShape.val lhsTransposePerm |>.get! + let rhsTransposedShape := permute rhsShape.val rhsTransposePerm |>.get! 
+ let batchShape := lhsTransposedShape.take lhsBatchingDims.length + let lhsResultShape := lhsTransposedShape.drop lhsBatchingDims.length |>.take lhsResultDims.length + let rhsResultShape := rhsTransposedShape.drop rhsBatchingDims.length |>.take rhsResultDims.length + let contractingShape := lhsTransposedShape.drop (lhsBatchingDims.length + lhsResultDims.length) |> + List.take (lhsTransposedShape.length - (lhsBatchingDims.length + lhsResultDims.length)) + let batchSize := if batchShape.isEmpty then 1 else batchShape.foldl (· * ·) 1 + let resultShape := batchShape ++ lhsResultShape ++ rhsResultShape + let lhsResultSize := if lhsResultShape.isEmpty then 1 else lhsResultShape.foldl (· * ·) (1 : Nat) + let rhsResultSize := if rhsResultShape.isEmpty then 1 else rhsResultShape.foldl (· * ·) (1 : Nat) + let contractingSize := if contractingShape.isEmpty then 1 else contractingShape.foldl (· * ·) 1 + /- Create fresh variable names for intermediate results -/ + let lhsTransposedName ← gensym (lhs ++ "_transposed") + let rhsTransposedName ← gensym (rhs ++ "_transposed") + let lhsReshapedName ← gensym (lhs ++ "_reshaped") + let lhsReshapedShape := [batchSize, lhsResultSize, contractingSize] + let lhsReshapedTy := TensorTy.mk (.mk lhsReshapedShape) dtype + let rhsReshapedName ← gensym (rhs ++ "_reshaped") + let rhsReshapedShape := [batchSize, rhsResultSize, contractingSize] + let rhsReshapedTy := TensorTy.mk (.mk rhsReshapedShape) dtype + let resultReshapedName ← gensym (output ++ "_reshaped") + let resultReshapedShape := [batchSize, lhsResultSize, rhsResultSize] + let resultReshapedType := TensorTy.mk (.mk resultReshapedShape) dtype + /- Emit the TGR statements for the dotGeneral operation -/ + pure ([ + .comment "Dot General Operation", + .assign lhsTransposedName (.transpose lhs lhsTransposePerm) (.mk (.mk lhsTransposedShape) dtype), + .assign rhsTransposedName (.transpose rhs rhsTransposePerm) (.mk (.mk rhsTransposedShape) dtype), + .assign lhsReshapedName (.reshape 
lhsTransposedName lhsReshapedTy.shape) lhsReshapedTy, + .assign rhsReshapedName (.reshape rhsTransposedName rhsReshapedTy.shape) rhsReshapedTy, + .assign resultReshapedName (.batchMatmul lhsReshapedName rhsReshapedName) resultReshapedType, + .assign output (.reshape resultReshapedName (.mk resultShape)) outputTy, + ]) + | .concatenate => + let tensors := inputValues + let dim ← (← lookupAttributeValue inputAttributes "dimension") |> parseNatFromElementLiteral + pure [.assign output (.concat tensors dim) outputTy] + /- tensor ternary operators -/ + | .select => + let cond := inputValues[0]! + let a := inputValues[1]! + let b := inputValues[2]! + pure [.assign output (.select cond a b) outputTy] + | _ => throw s!"Unsupported HLO operation: {repr opCode}" + | .return ops _ => do + pure [Statement.ret ops] + | .call callee inputValues outputs signature => do + let output ← match outputs with + | [output] => pure output + | _ => throw "Call operator signature must have a single output." + pure [ + Statement.assign + output + (.call callee inputValues) + (← parseSingleTensorTypeFromValueTypes signature.range)] + + | s => throw s!"Unsupported operation type {repr s}" + +def compileFunc (f : StableHLO.Parsing.Function) : Compile Unit := do + let .mk args body := f.funcBody + let inputs ← args.mapM (fun ⟨ name, v ⟩ => do + match v with + | .tensorType t => parseTensorType t + | _ => throw s!"Function input {name} must have tensor type.") + let outputs ← f.funcType.range.mapM fun + | .tensorType t => parseTensorType t + | _ => throw "Function output must be a tensor type." 
+ /- Since arguments are referred to by index, emit a statement for each + argument that assigns it to a named variable -/ + let preamble ← args.mapIdxM (fun i ⟨ name, v ⟩ => do + match v with + | .tensorType t => + pure (Statement.assign name (.arg i) (← parseTensorType t)) + | _ => throw s!"Function input {name} must have tensor type.") + let statements ← body.flatMapM compileOp + let func := Function.mk f.funcId inputs outputs (preamble ++ statements) + addFunction func + +def compileModule (m : StableHLO.Parsing.Module) : Compile Unit := + m.modFuncs.forM compileFunc + +def compile (m : List StableHLO.Parsing.Module) : (Except String Unit) × Ctx := + let compiled := match m with + | [m] => compileModule m + | _ => throw "Only one module is supported for now." + match compiled.run default with + | .ok _ s => (.ok (), s) + | .error err s => (throw err, s) + +end KLR.TGR.Compile diff --git a/KLR/TGR/Dot.lean b/KLR/TGR/Dot.lean new file mode 100644 index 00000000..83a0926e --- /dev/null +++ b/KLR/TGR/Dot.lean @@ -0,0 +1,123 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.TGR.AST +import SHerLOC.Analysis.Graph + +open StableHLO.Analysis (Graph Edge Vertex) + +/- This module provides a way to convert an TGR function into a DOT graph representation. -/ +namespace KLR.TGR.Graph + +/- +Process the name `var` so that it can used as a node ID in DOT format. +Notably, IDs can't start with a digit, so we prefix it with "node_". 
+-/ +def sanitize (var : String) : String := + s!"node_{var}" + +def makeReturnNode (funcName : String) : Vertex := + .mk + s!"return_{funcName}" + (.mk [ + ("label", s!"return\\n{funcName}"), + ("shape", "box"), + ("style", "filled"), + ("fillcolor", "lightgray"), + ("color", "gray") + ]) +def makeOpNode (op : Operator) (output : String) (ty : KLR.TGR.TensorTy): Vertex := + let attrs := match op with + | .arg .. => [ + ("shape", "diamond"), + ("style", "filled"), + ("fillcolor", "lightgreen"), + ("color", "green") + ] + | .batchMatmul .. => [ + ("style", "filled"), + ("fillcolor", "lightpink"), + ("color", "red") + ] + | .slice .. => [ + ("style", "filled"), + ("fillcolor", "lightblue"), + ("color", "blue") + ] + | _ => [] + .mk + (sanitize output) + (.mk ([ + ("label", s!"{opName op}\\n{output}\\n{ty.shape}"), + ] ++ attrs)) + +def makeConstNode (op : String) (name : String) (shape : TensorTy) (usedBy : String): Vertex := + .mk + s!"node_const_{name}_{usedBy}" + (.mk [ + ("label", s!"{op}\\n{name}\\n{shape.shape}"), + ("shape", "diamond"), + ("style", "filled"), + ("fillcolor", "lightyellow"), + ("color", "yellow") + ]) + +def makeEdge (source : String) (dest : String) : Edge := + .mk + source + dest + (.mk []) + +/- +Convert an TGR function to a DOT graph, where each variable is a vertex +and an edge exists from A to B if A is used in the computation of B. + +Note: since constants are reused in many parts of the function, they can +cause the graph to have long edges that cross over other nodes. To avoid this, +we create a separate vertex for each use of a constant. +-/ +def graph (f : TGR.Function) : Graph := Id.run do + let mut vertices := [] + let mut edges := [] + /- Every variables in the function that is the result of a `constant` operatior -/ + let mut consts := f.statements.filterMap (fun + | .assign v (.const ..) shape => .some ("const", v, shape) + | .assign v (.full ..) 
shape => .some ("full", v, shape) + | _ => .none) + /- A closure that creates edges from a list of inputs to an output variable. + If the input is a constant, it creates a vertex for that constant. -/ + let (makeEdges : List String → String → (List Vertex) × (List Edge)) := fun inputs output => Id.run do + let mut vertices := [] + let mut edges := [] + for input in inputs do + if let .some (op, v, shape) := consts.find? fun (_, v, _) => v == input then + let node := makeConstNode op v shape output + vertices := node :: vertices + edges := (makeEdge node.id output) :: edges + else + edges := (makeEdge (sanitize input) output) :: edges + return (vertices, edges) + + /- Walk the program statements and create vertices and edges. -/ + for s in f.statements do + match s with + | .assign _ (.const ..) _ | .assign _ (.full ..) _ => pure () + | .assign v op ty => + let deps := dependencies op + let (newVertices, newEdges) ← makeEdges deps (sanitize v) + vertices := [makeOpNode op v ty] ++ newVertices ++ vertices + edges := newEdges ++ edges + | .ret vars => + let node := makeReturnNode f.name + let deps := vars + let (newVertices, newEdges) ← makeEdges deps node.id + vertices := [node] ++ newVertices ++ vertices + edges := newEdges ++ edges + | .comment _ => pure () + + pure $ .mk f.name vertices edges + +end KLR.TGR.Graph diff --git a/KLR/TGR/Py.lean b/KLR/TGR/Py.lean new file mode 100644 index 00000000..d94aad5a --- /dev/null +++ b/KLR/TGR/Py.lean @@ -0,0 +1,154 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Biberstein +-/ + +import KLR.Core.Operators +import KLR.TGR.AST +import KLR.Util +import SHerLOC +import TensorLib.Shape +import TensorLib.Slice + +open Std.Format +open TensorLib (Dtype Shape Slice) + +/- +This module converts an TGR program into a runnable Python program. 
+At present, it can't convert TGR constants to python constants and can't +take input tensors, so it is only helpful to ensure that the shape +annotations are correct and that the program is well-formed. +-/ +namespace KLR.TGR.Py + +structure FormatCtx where + indent : Nat := 0 + program : String := "" + +abbrev Format := Std.Format + +def dTypeToPy : Dtype → String + | .bool => "np.bool_" + | .int8 => "np.int8" + | .int16 => "np.int16" + | .int32 => "np.int32" + | .int64 => "np.int64" + | .uint8 => "np.uint8" + | .uint16 => "np.uint16" + | .uint32 => "np.uint32" + | .uint64 => "np.uint64" + | .float32 => "np.float32" + | .float64 => "np.float64" + +def binOpToPy : BinaryOp → String + | .add => "np.add" + | .sub => "np.subtract" + | .mul => "np.multiply" + | .div => "np.divide" + | .and => "np.logical_and" + | .max => "np.maximum" + | .cmp => "np.compare" + +def unaryOpToPy : UnaryOp → String + | .exp => "np.exp" + | .sqrt => "np.sqrt" + | .neg => "np.negative" + | .convert d => s!"(lambda x: x.as_type({dTypeToPy d}))" + +def reduceOpToPy : BinaryOp → String + | .add => "np.sum" + | .max => "np.max" + | op => panic! 
s!"Unsupported reduction operation: {op}" + +def intLitToPy : StableHLO.Parsing.IntegerLiteral → String + | .mk .plus decimal => s!"{decimal}" + | .mk .minus decimal => s!"-{decimal}" + +def floatLitToPy : StableHLO.Parsing.FloatLiteral → String + | .decimal (.mk intPart fracPart sciPart) => + let intPartStr := intLitToPy intPart + let fracPartStr := intLitToPy fracPart + let sciPartStr := intLitToPy sciPart + s!"{intPartStr}.{fracPartStr}e{sciPartStr}" + | .hexaDecimal n => toString n + +def elementLitToPy : StableHLO.Parsing.ElementLiteral → String + | .booleanLiteral .true => "True" + | .booleanLiteral .false => "False" + | .floatLiteral f => floatLitToPy f + | .complexLiteral { real, imaginary } => + s!"complex({floatLitToPy real}, {floatLitToPy imaginary})" + | .stringLiteral str => s!"'{str}'" + +def valueToPy : StableHLO.Parsing.DenseLiteral → String + | .denseDimension n => s!"[{n.map valueToPy |> ", ".intercalate}]" + | .denseElements arr => ",".intercalate (arr.map elementLitToPy) + +def shapeToPy (s : Shape) : String := + s.val.map toString |> ",".intercalate + +def varToPy (arg : Var) : String := + /- Prefix, since Python variables can't start with a digit -/ + s!"var_{arg}" + +def opToPy (op : Operator) : String := + match op with + | .arg index => s!"args[{index}]" + | .binaryOp binOp a b => s!"{binOpToPy binOp}({varToPy a}, {varToPy b})" + | .unaryOp unOp a => s!"{unaryOpToPy unOp}({varToPy a})" + | .reductionOp redOp a b dim => s!"{reduceOpToPy redOp}({varToPy a}, initial={varToPy b}, axis={dim[0]!})" + | .batchMatmul a b => s!"np.einsum(\"bij,bkj->bik\", {varToPy a}, {varToPy b})" + | .arange start stop step shape => s!"np.arange({start}, {stop}, {step}).reshape({shapeToPy shape})" + | .concat tensors dim => + let tensorsStr := String.intercalate "," (tensors.map toString) + s!"np.concatenate([{tensorsStr}], axis={dim})" + | .select cond a b => s!"np.where({cond}, {varToPy a}, {varToPy b})" + | .full _value shape => s!"np.full(({shapeToPy shape}), 
0)" -- TODO: make this use the actual value + | .transpose a dims => + let dimsStr := dims.map toString |> ", ".intercalate + s!"np.transpose({varToPy a}, axes=[{dimsStr}])" + | .reshape a shape => s!"{varToPy a}.reshape({shapeToPy shape})" + | .broadcast a shape => s!"np.broadcast_to({varToPy a}, ({shapeToPy shape}))" + | .const t => s!"np.random.random(({shapeToPy t.shape}))" -- TODO: make this use the actual constant value + | .gather .. => panic! s!"Gather operation not implemented in Python translation" + | .slice .. => panic! s!"Slice operation not implemented in Python translation" + | .call .. => + panic! s!"Can't translate call operators to Python" + +def compileStatement (s : Statement) : Format := + match s with + | .comment msg => text s!"# {msg}" + | .assign dest op {shape, ..} => + text (s!"{varToPy dest} = {opToPy op} # {shape}") ++ + if shape.ndim != 0 then + line ++ text (s!"assert {varToPy dest}.shape == ({shapeToPy shape},), \"Expected %s, got %s\" % (({shapeToPy shape}), {varToPy dest}.shape)") + else + nil + | .ret vars => + let varNames := vars.map varToPy |> ", ".intercalate + text s!"return {varNames}" + +def compileFunction (f : Function) : Format := + let inputsStr := f.inputs.map (fun {shape, ..} => s!"\"np.ndarray[({shapeToPy shape})]\"") |> ", ".intercalate |> fun x => s!"args: Tuple[{x}]" + let outputsStr := f.outputs.map (fun {shape, ..} => shapeToPy shape) |> ", ".intercalate + let funcHeader := s!"def {f.name}({inputsStr}) -> (\"np.ndarray[({outputsStr})]\"):" + let args := f.inputs.map (fun {shape, ..} => s!"np.random.random(({shapeToPy shape}))") + let funCall := s!"{f.name}([{", ".intercalate args}])" + text funcHeader ++ + (nest 4 (prefixJoin line (f.statements.map compileStatement))) ++ line ++ + text funCall + +def compileProgram (p : Program) : Format := + let lines := [ + text "import numpy as np", + text "import jax", + text "from typing import Tuple"] ++ + p.functions.map compileFunction + joinSep lines line + +/- Compile 
the TGR program to a Python program. -/ +def compile (p : Program) : String := + (compileProgram p).pretty + +end KLR.TGR.Py diff --git a/KLR/Util/Gzip/lake-manifest.json b/KLR/Util/Gzip/lake-manifest.json index b44f34e4..01308411 100644 --- a/KLR/Util/Gzip/lake-manifest.json +++ b/KLR/Util/Gzip/lake-manifest.json @@ -1,7 +1,17 @@ {"version": "1.1.0", "packagesDir": ".lake/packages", "packages": - [{"type": "path", + [{"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": false, + "configFile": "lakefile.lean"}, + {"type": "path", "scope": "", "name": "Util", "manifestFile": "lake-manifest.json", @@ -76,4 +86,4 @@ "inherited": true, "configFile": "lakefile.toml"}], "name": "Gzip", - "lakeDir": ".lake"} + "lakeDir": ".lake"} \ No newline at end of file diff --git a/KLR/Util/Hex.lean b/KLR/Util/Hex.lean index ea77d554..4136e248 100644 --- a/KLR/Util/Hex.lean +++ b/KLR/Util/Hex.lean @@ -40,6 +40,13 @@ private def hexCharToUInt8 (high : Char) (low : Char) : Option UInt8 := do let lowNibble ← hexCharToNibble low return (highNibble <<< 4) + lowNibble +def hexCharsToBitVecBE (c0 c1 c2 c3 c4 c5 c6 c7: Char) : Option (BitVec 32) := do + let b0 := (← hexCharToUInt8 c0 c1).toBitVec + let b1 := (← hexCharToUInt8 c2 c3).toBitVec + let b2 := (← hexCharToUInt8 c4 c5).toBitVec + let b3 := (← hexCharToUInt8 c6 c7).toBitVec + pure (b0 ++ b1 ++ b2 ++ b3) + def decode (s : String) : Option ByteArray := Id.run do let rec split : List Char -> List (Char × Char) | [] | [_] => [] diff --git a/KLR/lake-manifest.json b/KLR/lake-manifest.json index 65a9531e..49c699d0 100644 --- a/KLR/lake-manifest.json +++ b/KLR/lake-manifest.json @@ -1,7 +1,17 @@ {"version": "1.1.0", "packagesDir": ".lake/packages", "packages": - [{"url": 
"https://github.com/leanprover/TensorLib.git", + [{"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": false, + "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover/TensorLib.git", "type": "git", "subDir": null, "scope": "", @@ -62,4 +72,4 @@ "inherited": true, "configFile": "lakefile.toml"}], "name": "Util", - "lakeDir": ".lake"} + "lakeDir": ".lake"} \ No newline at end of file diff --git a/KLR/lakefile.lean b/KLR/lakefile.lean index 2150fff4..5fd33e39 100644 --- a/KLR/lakefile.lean +++ b/KLR/lakefile.lean @@ -28,5 +28,8 @@ require plausible from git require TensorLib from git "https://github.com/leanprover/TensorLib.git" @ "v0.0.12" +require SHerLOC from git + "https://github.com/leanprover/SHerLOC.git" @ "c74ae090d4326cca9ff636184c330a67ca039ef6" + -- Comment the above and uncomment this for local development -- require TensorLib from "../../TensorLib" diff --git a/Main.lean b/Main.lean index 9898e56e..2813d5f6 100644 --- a/Main.lean +++ b/Main.lean @@ -18,6 +18,9 @@ import Cli import KLR import TensorLib.Npy import TensorLib.Tensor +import SHerLOC +import KLR.TGR +import SHerLOC.Analysis.Graph open Cli open KLR @@ -298,6 +301,32 @@ def evalKLR (p : Parsed) : IO UInt32 := do TensorLib.Npy.Ndarray.save! arr filename return 0 +def hloToTGR (p : Parsed) : IO UInt32 := do + let file := p.positionalArg! "file" |>.as! String + let s <- IO.FS.readFile file + match StableHLO.Parsing.parse s with + | .ok (hlo, _) => + let tgr := KLR.TGR.Compile.compile hlo + match tgr with + | (.ok _, s) => do + let tgr := s.program + IO.println (toString tgr) + let headFunction := s.program.functions.head! 
+ -- print graph of function + let g := KLR.TGR.Graph.graph headFunction |> toString + writeContent "dot" p g + -- print TGR program as Python program + let py := KLR.TGR.Py.compile tgr + writeContent "py" p py + return 0 + | (.error e, s) => do + IO.eprintln s!"Error compiling HLO to TGR: {e}" + IO.eprintln s!"{repr s}" + return 1 + | .error e => + IO.eprintln e + return 1 + -- -- Command configuration def gatherCmd := `[Cli| @@ -388,6 +417,15 @@ def evalKLRCmd := `[Cli| kernelFunctionName : String; "Name of the kernel function" ...inputFiles : String; ".npy files corresponding to the inputs to the kernel, in positional order" ] +def hloToTGRCmd := `[Cli| + "hlo-to-tgr" VIA hloToTGR; + "Compile HLO graph to TGR graph" + + FLAGS: + o, outfile : String; "Name of output file" + ARGS: + file : String; "File of HLO graph in .mlir format" +] def klrCmd : Cmd := `[Cli| klr NOOP; ["0.0.12"] @@ -400,7 +438,8 @@ def klrCmd : Cmd := `[Cli| infoCmd; nkiToKLRCmd; traceCmd; - typecheckCmd + typecheckCmd; + hloToTGRCmd ] def main (args : List String) : IO UInt32 := do diff --git a/lake-manifest.json b/lake-manifest.json index ccb3050a..2c28de0d 100644 --- a/lake-manifest.json +++ b/lake-manifest.json @@ -8,6 +8,16 @@ "inherited": false, "dir": "KLR", "configFile": "lakefile.lean"}, + {"url": "https://github.com/leanprover/SHerLOC.git", + "type": "git", + "subDir": null, + "scope": "", + "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "name": "SHerLOC", + "manifestFile": "lake-manifest.json", + "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", + "inherited": false, + "configFile": "lakefile.lean"}, {"url": "https://github.com/leanprover/TensorLib.git", "type": "git", "subDir": null, diff --git a/lakefile.lean b/lakefile.lean index 34142f2e..d5441db5 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -58,6 +58,9 @@ require plausible from git require TensorLib from git "https://github.com/leanprover/TensorLib.git" @ "v0.0.13" +require SHerLOC from git + 
"https://github.com/leanprover/SHerLOC.git" @ "c74ae090d4326cca9ff636184c330a67ca039ef6" + -- Comment the above and uncomment this for local development -- require TensorLib from "../TensorLib"