Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
2286c5e
feat: LoRA finetuning framework for BitNet (Alpaca-style instruction …
junjihashimoto Mar 31, 2026
84e5044
feat: verified backward pass specs with numerical gradient checks
junjihashimoto Mar 31, 2026
b1a5339
feat: verified AD framework + attention backward kernels
junjihashimoto Mar 31, 2026
35f0d2f
docs: verified AD guide + wrong backward detection test
junjihashimoto Mar 31, 2026
dbdbc4e
docs: comprehensive LoRA finetuning development guide
junjihashimoto Mar 31, 2026
9cfa90f
feat: multi-layer backward with per-layer saved activations
junjihashimoto Mar 31, 2026
bb7f8e9
feat: PyTorch-standard training infrastructure
junjihashimoto Mar 31, 2026
8faa331
fix: float literal precision in WGSL codegen (fixes AdamW NaN)
junjihashimoto Mar 31, 2026
c0b1516
feat: proper float parser + LSpec tests for parseFloat and floatToWGSL
junjihashimoto Mar 31, 2026
aaec444
test: add alpaca_facts.json test dataset for LoRA verification
junjihashimoto Mar 31, 2026
ca50eb4
feat: proper softmax backward + per-layer attnBuf saving + --max-grad…
junjihashimoto Apr 1, 2026
2bb98f6
feat: SafeBuffer module + RMSNorm backward + NaN-safe save
junjihashimoto Apr 1, 2026
614eb58
wip: O projection backward (BitLinear transpose) + SafeBuffer improve…
junjihashimoto Apr 1, 2026
08760b3
fix: BitLinear transpose kernel (O projection backward) with correct …
junjihashimoto Apr 1, 2026
d7d1feb
docs: backward completeness plan — root cause analysis + type-safe ch…
junjihashimoto Apr 1, 2026
be51c58
feat: residual backward accumulation + savedActivation diagnosis
junjihashimoto Apr 1, 2026
3e4dcbf
fix: use GPU kernel for zeroBuffer (prevent batch/writeBuffer conflict)
junjihashimoto Apr 1, 2026
f951599
fix: remove incorrect dInput accumulation in residual backward
junjihashimoto Apr 1, 2026
58919f1
fix: floatArrayToBytes Float64→Float32 conversion + RMSNorm backward …
junjihashimoto Apr 1, 2026
f686c07
feat: enable RMSNorm backward in full attention backward chain
junjihashimoto Apr 2, 2026
46002a6
feat: Final RMSNorm backward (LM head → last layer gradient)
junjihashimoto Apr 2, 2026
7bf3522
feat: PyTorch-matching defaults (lr=2e-4, clip=1.0, warmup=6%)
junjihashimoto Apr 2, 2026
786f882
feat: type-safe backward chain (DiffChain) + completeness test
junjihashimoto Apr 2, 2026
64fdb99
test: attention backward chain complete (7/7 ops, dimension check PASS)
junjihashimoto Apr 2, 2026
8111e4e
feat: FFN backward complete — full transformer backward chain
junjihashimoto Apr 2, 2026
2b8c779
test: GPU vs CPU backward consistency — all 4 kernels PASS
junjihashimoto Apr 2, 2026
9d4410c
feat: type-safe BackwardOps registry with compile-time completeness g…
junjihashimoto Apr 2, 2026
3c03fea
feat: kernel fusion framework design + BitLinear transpose optimization
junjihashimoto Apr 2, 2026
76ca099
feat: kernel fusion framework + fused LoRA forward (B@h+add)
junjihashimoto Apr 2, 2026
b8cb078
feat: Flash Attention with equivalence proof
junjihashimoto Apr 2, 2026
e5a949e
feat: Flash Attention GPU kernel + equivalence tests
junjihashimoto Apr 2, 2026
2e7480a
fix: Flash Attention Exp.var snapshot bug + GPU test PASS (error=0.0)
junjihashimoto Apr 2, 2026
a48a7a8
fix: Flash Attention dynamic cacheLen + WGSL uniformity workaround
junjihashimoto Apr 2, 2026
a340692
feat: Flash Attention production integration (3 kernels → 2)
junjihashimoto Apr 2, 2026
dd6a5f1
perf: revert to standard attention path (faster for short context)
junjihashimoto Apr 2, 2026
ee6a409
feat: Tiled Flash Attention v2 — 34 TPS (up from 30 TPS v1)
junjihashimoto Apr 2, 2026
71ffcec
perf: eliminate copy dispatch + pre-allocate partial buffer (35.6 TPS)
junjihashimoto Apr 2, 2026
e2eb52c
feat: in-place Flash Attention kernel + standard path for production
junjihashimoto Apr 2, 2026
07c0e45
feat: Flash Attention production — 40 TPS (up from 37 TPS standard)
junjihashimoto Apr 4, 2026
e8e1d5d
docs: move CHANGELOG to docs/ + add STATUS.md with plan and remaining…
junjihashimoto Apr 5, 2026
3828d0f
docs: update README with Flash Attention, LoRA finetuning, verified AD
junjihashimoto Apr 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions Examples/BitNetComplete.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import Hesper.Inference.Sampling
import Hesper.Tokenizer.SentencePiece
import Hesper.GGUF.Reader
import Hesper.Logging
import Hesper.LoRA.Types
import Hesper.LoRA.IO
import Hesper.LoRA.Inference

/-!
# Complete BitNet Text Generation
Expand Down Expand Up @@ -57,20 +60,28 @@ def loadModel (ggufPath : String) : IO (BitNetModel × Tokenizer × Device × Op

return (model, tokenizer, device, tokenizer.vocab.eosToken)

/-- Scan `args` for `flag` and return the argument immediately following it,
    e.g. `findFlag ["--lora", "p"] "--lora" = some "p"`.
    Returns `none` when the flag is absent or is the last element (no value). -/
def findFlag (args : List String) (flag : String) : Option String :=
  match args with
  | key :: value :: tail =>
    if key == flag then some value
    else findFlag (value :: tail) flag
  | _ => none  -- empty list or lone trailing element: no value to return

/-- Run single-shot generation -/
def runGeneration (args : List String) : IO Unit := do
if args.length < 2 then
IO.println "Usage: bitnet-complete <gguf_model> <prompt> [max_tokens] [--stats] [--verbose]"
IO.println " bitnet-complete <gguf_model> --interactive"
IO.println " bitnet-complete <gguf_model> -i"
IO.println "Usage: bitnet-complete <gguf_model> <prompt> [max_tokens] [--stats] [--verbose] [--lora <path>]"
IO.println " bitnet-complete <gguf_model> --interactive [--lora <path>]"
IO.println " bitnet-complete <gguf_model> -i [--lora <path>]"
return

let ggufPath := args[0]!
let promptText := args[1]!
let showStats := args.any (· == "--stats")
let verbose := args.any (· == "--verbose")
let loraPath := findFlag args "--lora"
-- Filter out flags before parsing max_tokens
let positionalArgs := args.filter (fun a => !a.startsWith "--")
let positionalArgs := args.filter (fun a => !a.startsWith "--" && a != (loraPath.getD ""))
let maxTokens := if positionalArgs.length >= 3 then positionalArgs[2]!.toNat! else 20

-- Disable verbose by default for clean output
Expand All @@ -80,6 +91,9 @@ def runGeneration (args : List String) : IO Unit := do
IO.println " BitNet Text Generation"
IO.println "═══════════════════════════════════════════════"
IO.println s!"Model: {ggufPath}"
match loraPath with
| some p => IO.println s!"LoRA: {p}"
| none => pure ()
IO.println s!"Prompt: \"{promptText}\""
IO.println s!"Max tokens: {maxTokens}"
IO.println ""
Expand All @@ -90,7 +104,14 @@ def runGeneration (args : List String) : IO Unit := do
IO.println s!"Prompt tokens ({promptTokens.size}): {promptTokens}"
IO.println ""

let outputTokens ← generate device model promptTokens maxTokens .Greedy eosToken showStats
let outputTokens ← match loraPath with
| some p =>
-- Load LoRA adapter and generate with it
let adapter ← Hesper.LoRA.IO.loadAdapter device p model.config.dim model.config.kvDim
let loraState ← Hesper.LoRA.Inference.createLoRAInferenceState device adapter model.config.dim model.config.kvDim
Hesper.LoRA.Inference.generateWithLoRA device model adapter loraState promptTokens maxTokens .Greedy eosToken
| none =>
generate device model promptTokens maxTokens .Greedy eosToken showStats

let outputText := decode tokenizer outputTokens
IO.println ""
Expand All @@ -99,15 +120,26 @@ def runGeneration (args : List String) : IO Unit := do
IO.println "─────────────────────────────────────────"

/-- Run interactive REPL -/
def runInteractive (ggufPath : String) : IO Unit := do
def runInteractive (ggufPath : String) (loraPath : Option String := none) : IO Unit := do
IO.println "═══════════════════════════════════════════════"
IO.println " BitNet Interactive Mode"
IO.println "═══════════════════════════════════════════════"
IO.println s!"Model: {ggufPath}"
match loraPath with
| some p => IO.println s!"LoRA: {p}"
| none => pure ()
IO.println ""

let (model, tokenizer, device, eosToken) ← loadModel ggufPath

-- Load LoRA if specified
let loraOpt ← match loraPath with
| some p =>
let adapter ← Hesper.LoRA.IO.loadAdapter device p model.config.dim model.config.kvDim
let loraState ← Hesper.LoRA.Inference.createLoRAInferenceState device adapter model.config.dim model.config.kvDim
pure (some (adapter, loraState))
| none => pure none

-- Disable verbose logging for clean interactive output
setVerbose false

Expand Down Expand Up @@ -167,7 +199,11 @@ def runInteractive (ggufPath : String) : IO Unit := do
let promptTokens := encode tokenizer input
IO.println s!"[{promptTokens.size} tokens] Generating..."

let outputTokens ← generate device model promptTokens maxTokens .Greedy eosToken
let outputTokens ← match loraOpt with
| some (adapter, loraState) =>
Hesper.LoRA.Inference.generateWithLoRA device model adapter loraState promptTokens maxTokens .Greedy eosToken
| none =>
generate device model promptTokens maxTokens .Greedy eosToken

let newTokenCount := outputTokens.size - promptTokens.size
let outputText := decode tokenizer outputTokens
Expand All @@ -189,7 +225,8 @@ def main (args : List String) : IO Unit := do
return

let arg1 := args[1]!
let loraPath := findFlag args "--lora"
if arg1 == "--interactive" || arg1 == "-i" then
runInteractive args[0]!
runInteractive args[0]! loraPath
else
runGeneration args
Loading
Loading