From dd3e2091ae38f419fc682e720c962cd6cc073641 Mon Sep 17 00:00:00 2001 From: Kirsten <32720576+kirstenmg@users.noreply.github.com> Date: Thu, 23 May 2024 14:14:52 -0700 Subject: [PATCH 1/2] Make passthrough work with state edges for terminating loops --- dag_in_context/src/lib.rs | 1 + .../src/loop_iteration_analysis.egg | 76 +++++++++++++++++++ .../src/optimizations/loop_unroll.egg | 68 +---------------- .../src/optimizations/passthrough.egg | 13 +++- 4 files changed, 90 insertions(+), 68 deletions(-) create mode 100644 dag_in_context/src/loop_iteration_analysis.egg diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index 4aabd2259..95472e7c0 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -57,6 +57,7 @@ pub fn prologue() -> String { include_str!("utility/expr_size.egg"), include_str!("utility/drop_at.egg"), include_str!("interval_analysis.egg"), + include_str!("loop_iteration_analysis.egg"), include_str!("optimizations/switch_rewrites.egg"), include_str!("optimizations/select.egg"), include_str!("optimizations/peepholes.egg"), diff --git a/dag_in_context/src/loop_iteration_analysis.egg b/dag_in_context/src/loop_iteration_analysis.egg new file mode 100644 index 000000000..4483304ca --- /dev/null +++ b/dag_in_context/src/loop_iteration_analysis.egg @@ -0,0 +1,76 @@ +;; Analysis to get the number of iterations of a loop +(ruleset loop-iter-analysis) + +;; inputs, outputs -> number of iterations +;; The minimum possible guess is 1 because of do-while loops +(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new))) + +;; Marks loops that we know will terminate +(relation TerminatingLoop (Expr Expr)) + +;; by default, guess that all loops run 1000 times +(rule ((DoWhile inputs outputs)) + ((set (LoopNumItersGuess inputs outputs) 1000)) + :ruleset loop-iter-analysis) + +;; For a loop that is false, its num iters is 1 +(rule + ((= loop (DoWhile inputs outputs)) + (= (Const (Bool false) ty ctx) (Get outputs 0))) + ((set (LoopNumItersGuess inputs outputs) 1) + (TerminatingLoop inputs outputs)) +:ruleset loop-iter-analysis) + +;; Figure out number of iterations for a loop with constant bounds and initial value +;; and i is updated before checking pred +;; TODO: we could make it work for decrementing loops +(rule + ((= lhs (DoWhile inputs outputs)) + (= num-inputs (tuple-length inputs)) + (= pred (Get outputs 0)) + ;; iteration counter starts at start_const + (= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i)) + ;; updated counter at counter_i + (= next_counter (Get outputs (+ counter_i 1))) + ;; increments by some constant each loop + (= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i) + (Const (Int increment) _ty2 _ctx2))) + (> increment 0) + ;; while next_counter less than end_constant + (= pred (Bop (LessThan) next_counter + (Const (Int end_constant) _ty3 _ctx3))) + ;; end constant is at least start constant + (>= end_constant start_const) + ) + ( + (set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment)) + (TerminatingLoop inputs outputs) + ) + :ruleset loop-iter-analysis) + +;; Figure out number of iterations for a loop with constant bounds and initial value +;; and i is updated after checking pred +(rule + ((= lhs (DoWhile inputs outputs)) + (= num-inputs (tuple-length inputs)) + (= pred (Get outputs 0)) + ;; iteration counter starts at start_const + (= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i)) + ;; updated counter at counter_i + (= next_counter (Get outputs (+ counter_i 1))) + ;; increments by a constant each loop + (= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i) + (Const (Int increment) _ty2 _ctx2))) + (> increment 0) + ;; while this counter less than end_constant + (= pred (Bop (LessThan) (Get (Arg _ty _ctx) counter_i) + (Const (Int end_constant) _ty3 _ctx3))) + ;; end constant is at least start constant + (>= end_constant start_const) + ) + ( + (set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1)) + (TerminatingLoop inputs outputs) + ) + :ruleset loop-iter-analysis) + diff --git a/dag_in_context/src/optimizations/loop_unroll.egg b/dag_in_context/src/optimizations/loop_unroll.egg index 3120568d9..a2cd4756c 100644 --- a/dag_in_context/src/optimizations/loop_unroll.egg +++ b/dag_in_context/src/optimizations/loop_unroll.egg @@ -1,75 +1,9 @@ ;; Some simple simplifications of loops +;; Depends on loop iteration analysis (ruleset loop-unroll) (ruleset loop-peel) (ruleset loop-iters-analysis) -;; inputs, outputs -> number of iterations -;; The minimum possible guess is 1 because of do-while loops -;; TODO: dead loop deletion can turn loops with a false condition to a body -(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new))) - -;; by default, guess that all loops run 1000 times -(rule ((DoWhile inputs outputs)) - ((set (LoopNumItersGuess inputs outputs) 1000)) - :ruleset loop-iters-analysis) - -;; For a loop that is false, its num iters is 1 -(rule - ((= loop (DoWhile inputs outputs)) - (= (Const (Bool false) ty ctx) (Get outputs 0))) - ((set (LoopNumItersGuess inputs outputs) 1)) -:ruleset loop-iters-analysis) - -;; Figure out number of iterations for a loop with constant bounds and initial value -;; and i is updated before checking pred -;; TODO: we could make it work for decrementing loops -(rule - ((= lhs (DoWhile inputs outputs)) - (= pred (Get outputs 0)) - ;; iteration counter starts at start_const - (= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i)) - ;; updated counter at counter_i - (= next_counter (Get outputs (+ counter_i 1))) - ;; increments by some constant each loop - (= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i) - (Const (Int increment) _ty2 _ctx2))) - (> increment 0) - ;; while next_counter less than end_constant - (= pred (Bop (LessThan) next_counter - (Const (Int end_constant) _ty3 _ctx3))) - ;; end constant is at least start constant - (>= end_constant start_const) - ) - ( - (set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment)) - ) - :ruleset loop-iters-analysis) - -;; Figure out number of iterations for a loop with constant bounds and initial value -;; and i is updated after checking pred -(rule - ((= lhs (DoWhile inputs outputs)) - (= pred (Get outputs 0)) - ;; iteration counter starts at start_const - (= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i)) - (= body-arg (Get (Arg _ty _ctx) counter_i)) - ;; updated counter at counter_i - (= next_counter (Get outputs (+ counter_i 1))) - ;; increments by a constant each loop - (= next_counter (Bop (Add) body-arg - (Const (Int increment) _ty2 _ctx2))) - (> increment 0) - ;; while this counter less than end_constant - (= pred (Bop (LessThan) body-arg - (Const (Int end_constant) _ty3 _ctx3))) - ;; end constant is at least start constant - (>= end_constant start_const) - ) - ( - (set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1)) - ) - :ruleset loop-iters-analysis) - ;; loop peeling rule ;; Only peel loops that we know iterate < 3 times (function LoopPeeledPlaceholder (Expr) Assumption :unextractable) diff --git a/dag_in_context/src/optimizations/passthrough.egg b/dag_in_context/src/optimizations/passthrough.egg index c420723d0..1820df085 100644 --- a/dag_in_context/src/optimizations/passthrough.egg +++ b/dag_in_context/src/optimizations/passthrough.egg @@ -1,7 +1,8 @@ +;; Relies on loop iteration analysis (ruleset passthrough) -;; Pass through thetas +;; Pass through thetas: pure case (rule ((= lhs (Get loop i)) (= loop (DoWhile inputs pred-outputs)) (= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i)) @@ -13,6 +14,16 @@ ((union lhs (Get inputs i))) :ruleset passthrough) +;; Pass through thetas: state edge case +(rule ((= lhs (Get loop i)) + (= loop (DoWhile inputs pred-outputs)) + (= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i)) + ;; It is OK to pass through state edges as long as the loop terminates + (TerminatingLoop inputs pred-outputs) + ) + ((union lhs (Get inputs i))) + :ruleset passthrough) + ;; Pass through switch arguments (rule ((= lhs (Get switch i)) (= switch (Switch pred inputs branches)) From a20d8e01e06b88eac935ff587b86d61703b15ebb Mon Sep 17 00:00:00 2001 From: Kirsten <32720576+kirstenmg@users.noreply.github.com> Date: Thu, 30 May 2024 12:50:10 -0700 Subject: [PATCH 2/2] Make loop passthrough work; performance issues though --- .../src/loop_iteration_analysis.egg | 11 +++---- .../src/optimizations/loop_unroll.egg | 1 - .../src/optimizations/passthrough.egg | 29 +++++++++++++++---- dag_in_context/src/utility/util.egg | 4 +++ tests/passing/small/dead_loop_deletion.bril | 15 ++++++++++ .../small/loop_state_pass_through.bril | 14 +++++++++ 6 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 tests/passing/small/dead_loop_deletion.bril create mode 100644 tests/passing/small/loop_state_pass_through.bril diff --git a/dag_in_context/src/loop_iteration_analysis.egg b/dag_in_context/src/loop_iteration_analysis.egg index 4483304ca..cfb324f55 100644 --- a/dag_in_context/src/loop_iteration_analysis.egg +++ b/dag_in_context/src/loop_iteration_analysis.egg @@ -1,5 +1,5 @@ ;; Analysis to get the number of iterations of a loop -(ruleset loop-iter-analysis) +(ruleset loop-iters-analysis) ;; inputs, outputs -> number of iterations ;; The minimum possible guess is 1 because of do-while loops @@ -11,7 +11,7 @@ ;; by default, guess that all loops run 1000 times (rule ((DoWhile inputs outputs)) ((set (LoopNumItersGuess inputs outputs) 1000)) - :ruleset loop-iter-analysis) + :ruleset loop-iters-analysis) ;; For a loop that is false, its num iters is 1 (rule @@ -19,7 +19,7 @@ (= (Const (Bool false) ty ctx) (Get outputs 0))) ((set (LoopNumItersGuess inputs outputs) 1) (TerminatingLoop inputs outputs)) -:ruleset loop-iter-analysis) +:ruleset loop-iters-analysis) ;; Figure out number of iterations for a loop with constant bounds and initial value ;; and i is updated before checking pred @@ -33,6 +33,7 @@ ;; updated counter at counter_i (= next_counter (Get outputs (+ counter_i 1))) ;; increments by some constant each loop + ;; TODO: how to handle the invariant case? (= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i) (Const (Int increment) _ty2 _ctx2))) (> increment 0) @@ -46,7 +47,7 @@ (set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment)) (TerminatingLoop inputs outputs) ) - :ruleset loop-iter-analysis) + :ruleset loop-iters-analysis) ;; Figure out number of iterations for a loop with constant bounds and initial value ;; and i is updated after checking pred @@ -72,5 +73,5 @@ (set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1)) (TerminatingLoop inputs outputs) ) - :ruleset loop-iter-analysis) + :ruleset loop-iters-analysis) diff --git a/dag_in_context/src/optimizations/loop_unroll.egg b/dag_in_context/src/optimizations/loop_unroll.egg index a2cd4756c..8f64abd6a 100644 --- a/dag_in_context/src/optimizations/loop_unroll.egg +++ b/dag_in_context/src/optimizations/loop_unroll.egg @@ -2,7 +2,6 @@ ;; Depends on loop iteration analysis (ruleset loop-unroll) (ruleset loop-peel) -(ruleset loop-iters-analysis) ;; loop peeling rule ;; Only peel loops that we know iterate < 3 times diff --git a/dag_in_context/src/optimizations/passthrough.egg b/dag_in_context/src/optimizations/passthrough.egg index 1820df085..9aac01107 100644 --- a/dag_in_context/src/optimizations/passthrough.egg +++ b/dag_in_context/src/optimizations/passthrough.egg @@ -14,14 +14,31 @@ ((union lhs (Get inputs i))) :ruleset passthrough) -;; Pass through thetas: state edge case -(rule ((= lhs (Get loop i)) - (= loop (DoWhile inputs pred-outputs)) +; ;; Pass through thetas: state edge case +(rule ((= loop (DoWhile inputs pred-outputs)) (= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i)) ;; It is OK to pass through state edges as long as the loop terminates - (TerminatingLoop inputs pred-outputs) - ) - ((union lhs (Get inputs i))) + (TerminatingLoop inputs pred-outputs)) + ( + ;; To maintain the linearity invariant, we must remove the state edge + ;; from the loop. + (let new-inputs (TupleRemoveAt inputs i)) + (let removed-outputs (TupleRemoveAt pred-outputs (+ i 1))) + (let new-outputs (DropAt (TmpCtx) i removed-outputs)) + + (let projected-old-loop (TupleRemoveAt loop i)) + (let new-loop (DoWhile new-inputs new-outputs)) + (union new-loop projected-old-loop) + + ;; Resolve the temporary context + (union (TmpCtx) (InLoop new-inputs new-outputs)) + (delete (TmpCtx)) + + ;; State edge can be gotten without the loop now + (union (Get loop i) (Get inputs i)) + + ;; Subsume the loop later + (ToSubsumeLoop inputs pred-outputs)) :ruleset passthrough) ;; Pass through switch arguments diff --git a/dag_in_context/src/utility/util.egg b/dag_in_context/src/utility/util.egg index 711b04614..def2c2191 100644 --- a/dag_in_context/src/utility/util.egg +++ b/dag_in_context/src/utility/util.egg @@ -75,4 +75,8 @@ ((subsume (If a b c d))) :ruleset subsume-after-helpers) +(relation ToSubsumeLoop (Expr Expr)) +(rule ((ToSubsumeLoop in p-out)) + ((subsume (DoWhile in p-out))) + :ruleset subsume-after-helpers) diff --git a/tests/passing/small/dead_loop_deletion.bril b/tests/passing/small/dead_loop_deletion.bril new file mode 100644 index 000000000..004b76140 --- /dev/null +++ b/tests/passing/small/dead_loop_deletion.bril @@ -0,0 +1,15 @@ +@main: int { + i: int = const 1; + forty: int = const 40; + one: int = const 1; + +.loop_body: + i: int = add i one; + cond: bool = lt i forty; + br cond .loop_body .loop_end; + +.loop_end: + j: int = const 2; + + ret j; +} diff --git a/tests/passing/small/loop_state_pass_through.bril b/tests/passing/small/loop_state_pass_through.bril new file mode 100644 index 000000000..4caa0a96d --- /dev/null +++ b/tests/passing/small/loop_state_pass_through.bril @@ -0,0 +1,14 @@ +# ARGS: 5 +@main(input: int) { + one: int = const 1; + i: int = const 1; + jmp .loop; +.loop: + max: int = const 10; + cond: bool = lt i max; + i: int = add i one; + br cond .loop .exit; +.exit: + res: int = add i input; + print res; +}