diff --git a/EML/.gitignore b/EML/.gitignore index 0e5f1e4b..a7940129 100644 --- a/EML/.gitignore +++ b/EML/.gitignore @@ -1,3 +1,11 @@ /_build /_coverage +*.o +*.s +*.S +*.ll +*.exe +*.a +*.so +*.dll diff --git a/EML/lib/backend/llvm_ir/analysis.ml b/EML/lib/backend/llvm_ir/analysis.ml index 450b23ac..f06018bf 100644 --- a/EML/lib/backend/llvm_ir/analysis.ml +++ b/EML/lib/backend/llvm_ir/analysis.ml @@ -43,6 +43,44 @@ let analyze (program : anf_program) = | AnfEval _ -> None) program in + let is_valid_linker_ident name = + String.length name > 0 + && String.for_all + (fun c -> + (c >= 'a' && c <= 'z') + || (c >= 'A' && c <= 'Z') + || (c >= '0' && c <= '9') + || Char.equal c '_') + name + in + let mangle_operator_for_linker name = + "op_" + ^ Base.String.concat_map name ~f:(function + | '*' -> "_star" + | '+' -> "_plus" + | '-' -> "_minus" + | '/' -> "_slash" + | '=' -> "_eq" + | '<' -> "_lt" + | '>' -> "_gt" + | '!' -> "_bang" + | '&' -> "_amp" + | '|' -> "_bar" + | '^' -> "_hat" + | '@' -> "_at" + | '~' -> "_tilde" + | '?' -> "_q" + | '.' -> "_dot" + | ':' -> "_colon" + | '%' -> "_percent" + | '$' -> "_dollar" + | c + when (c >= 'a' && c <= 'z') + || (c >= 'A' && c <= 'Z') + || (c >= '0' && c <= '9') + || Char.equal c '_' -> String.make 1 c + | c -> "_u" ^ Int.to_string (Char.code c)) + in let mangle_reserved name = if is_reserved name then "eml_" ^ name @@ -50,10 +88,18 @@ let analyze (program : anf_program) = then "eml_start" else name in + let asm_name_for_func func_name = + let base = + if is_valid_linker_ident func_name + then func_name + else mangle_operator_for_linker func_name + in + mangle_reserved base + in let functions, _ = List.fold_left (fun (reversed_functions, counts) (func_name, _arity, params, body, is_rec) -> - let base_asm_name = mangle_reserved func_name in + let base_asm_name = asm_name_for_func func_name in let duplicate_index = Base.Map.find counts func_name |> Option.value ~default:0 in diff --git a/EML/lib/backend/llvm_ir/generator.ml b/EML/lib/backend/llvm_ir/generator.ml index b274ac57..c0cf15ca 100644 --- a/EML/lib/backend/llvm_ir/generator.ml +++ b/EML/lib/backend/llvm_ir/generator.ml @@ -293,6 +293,7 @@ let gen_binop_native op left_v right_v = let* r = untag_bool_val right_v in let* v = with_optional_value (or_ builder l r "or") in tag_bool_result v + | Custom _ -> fail "Custom operator must be compiled to application" ;; let gen_unop_native op tagged_val = diff --git a/EML/lib/backend/ricsv/analysis.ml b/EML/lib/backend/ricsv/analysis.ml index cb11905d..32fa27d7 100644 --- a/EML/lib/backend/ricsv/analysis.ml +++ b/EML/lib/backend/ricsv/analysis.ml @@ -171,6 +171,44 @@ let analyze (program : anf_program) = | AnfEval _ -> None) program in + let is_valid_linker_ident name = + String.length name > 0 + && String.for_all + (fun c -> + (c >= 'a' && c <= 'z') + || (c >= 'A' && c <= 'Z') + || (c >= '0' && c <= '9') + || Char.equal c '_') + name + in + let mangle_operator_for_linker name = + "op_" + ^ Base.String.concat_map name ~f:(function + | '*' -> "_star" + | '+' -> "_plus" + | '-' -> "_minus" + | '/' -> "_slash" + | '=' -> "_eq" + | '<' -> "_lt" + | '>' -> "_gt" + | '!' -> "_bang" + | '&' -> "_amp" + | '|' -> "_bar" + | '^' -> "_hat" + | '@' -> "_at" + | '~' -> "_tilde" + | '?' -> "_q" + | '.' -> "_dot" + | ':' -> "_colon" + | '%' -> "_percent" + | '$' -> "_dollar" + | c + when (c >= 'a' && c <= 'z') + || (c >= 'A' && c <= 'Z') + || (c >= '0' && c <= '9') + || Char.equal c '_' -> String.make 1 c + | c -> "_u" ^ Int.to_string (Char.code c)) + in let mangle_reserved name = if is_reserved name then "eml_" ^ name @@ -178,6 +216,14 @@ let analyze (program : anf_program) = then "eml_start" else name in + let asm_name_for_func func_name = + let base = + if is_valid_linker_ident func_name + then func_name + else mangle_operator_for_linker func_name + in + mangle_reserved base + in let functions, _ = List.fold_left (fun (reversed_functions, generated_name_counts) @@ -189,7 +235,7 @@ let analyze (program : anf_program) = , slots_count , max_stack_args , max_create_tuple_array_bytes ) -> - let base_asm_name = mangle_reserved func_name in + let base_asm_name = asm_name_for_func func_name in let duplicate_index = Base.Map.find generated_name_counts func_name |> Option.value ~default:0 in diff --git a/EML/lib/backend/ricsv/architecture.ml b/EML/lib/backend/ricsv/architecture.ml index e4638586..65a54334 100644 --- a/EML/lib/backend/ricsv/architecture.ml +++ b/EML/lib/backend/ricsv/architecture.ml @@ -39,6 +39,7 @@ module Riscv_backend = struct | Xor of reg * reg * reg (* xor двух регистров: rd = rs1 ^ rs2 *) | Mul of reg * reg * reg (* умножение: rd = rs1 * rs2 *) | Div of reg * reg * reg (* целочисленное деление: rd = rs1 / rs2 *) + | Slli of reg * reg * int (* логический сдвиг влево на константу: rd = rs << imm *) | Srli of reg * reg * int (* логический сдвиг вправо на константу: rd = rs >>> imm *) let pp_reg ppf = function @@ -59,6 +60,7 @@ module Riscv_backend = struct | Sub (rd, rs1, rs2) -> fprintf ppf "sub %a, %a, %a" pp_reg rd pp_reg rs1 pp_reg rs2 | Mul (rd, rs1, rs2) -> fprintf ppf "mul %a, %a, %a" pp_reg rd pp_reg rs1 pp_reg rs2 | Div (rd, rs1, rs2) -> fprintf ppf "div %a, %a, %a" pp_reg rd pp_reg rs1 pp_reg rs2 + | Slli (rd, rs1, imm) -> fprintf ppf "slli %a, %a, %d" pp_reg rd pp_reg rs1 imm | Srli (rd, rs1, imm) -> fprintf ppf "srli %a, %a, %d" pp_reg rd pp_reg rs1 imm | Xori (rd, rs1, imm) -> fprintf ppf "xori %a, %a, %d" pp_reg rd pp_reg rs1 imm | Xor (rd, rs1, rs2) -> fprintf ppf "xor %a, %a, %a" pp_reg rd pp_reg rs1 pp_reg rs2 @@ -113,6 +115,7 @@ module Riscv_backend = struct let xor rd rs1 rs2 = [ Xor (rd, rs1, rs2) ] let mul rd rs1 rs2 = [ Mul (rd, rs1, rs2) ] let div rd rs1 rs2 = [ Div (rd, rs1, rs2) ] + let slli rd rs imm = [ Slli (rd, rs, imm) ] let srli rd rs imm = [ Srli (rd, rs, imm) ] let add_tag_items dst delta = [ Addi (dst, dst, delta) ] let arg_regs = [ a0; a1; a2; a3; a4; a5; a6; a7 ] diff --git a/EML/lib/backend/ricsv/architecture.mli b/EML/lib/backend/ricsv/architecture.mli index b3eef301..70113bd5 100644 --- a/EML/lib/backend/ricsv/architecture.mli +++ b/EML/lib/backend/ricsv/architecture.mli @@ -35,6 +35,7 @@ module Riscv_backend : sig | Xor of reg * reg * reg | Mul of reg * reg * reg | Div of reg * reg * reg + | Slli of reg * reg * int | Srli of reg * reg * int val pp_reg : Format.formatter -> reg -> unit @@ -76,6 +77,7 @@ module Riscv_backend : sig val xor : reg -> reg -> reg -> instr list val mul : reg -> reg -> reg -> instr list val div : reg -> reg -> reg -> instr list + val slli : reg -> reg -> int -> instr list val srli : reg -> reg -> int -> instr list val add_tag_items : reg -> int -> instr list val arg_regs : reg list diff --git a/EML/lib/backend/ricsv/generator.ml b/EML/lib/backend/ricsv/generator.ml index c4f78645..17c1f9ab 100644 --- a/EML/lib/backend/ricsv/generator.ml +++ b/EML/lib/backend/ricsv/generator.ml @@ -420,6 +420,8 @@ and gen_cexpr dst = function | ComplexImmediate imm -> gen_imm dst imm | ComplexUnarOper (Negative, op) -> gen_neg dst op | ComplexUnarOper (Not, op) -> gen_not dst op + | ComplexBinOper (Custom _, _, _) -> + fail "Custom operator must be compiled to application" | ComplexBinOper (op, left, right) -> gen_binop dst op left right | ComplexBranch (cond, then_e, else_e) -> gen_branch dst cond then_e else_e | ComplexField (tuple_imm, idx) -> gen_field dst tuple_imm idx @@ -453,15 +455,19 @@ let bind_param_to_stack env i = function | _ -> fail "unsupported pattern" ;; -let flush_instr_buffer ppf = +let flush_instr_buffer ~enable_peephole ppf = let* state = get in let instruction_buffer = state.instr_buffer in let* () = put { state with instr_buffer = [] } in - let () = List.iter (fun item -> format_item ppf item) (List.rev instruction_buffer) in + let instructions = List.rev instruction_buffer in + let instructions = + if enable_peephole then Peephole.optimize instructions else instructions + in + let () = List.iter (fun item -> format_item ppf item) instructions in return () ;; -let gen_func ~enable_gc asm_name params body frame_sz ppf = +let gen_func ~enable_gc ~enable_peephole asm_name params body frame_sz ppf = fprintf ppf "\n .globl %s\n .type %s, @function\n" asm_name asm_name; let args = List.length params in let params_reg, params_stack = @@ -486,11 +492,11 @@ let gen_func ~enable_gc asm_name params body frame_sz ppf = let* () = spill_params_to_frame params_reg in let* () = gen_anf result_reg body in let* () = append (epilogue ~enable_gc ~is_main:(String.equal asm_name "main")) in - let* () = flush_instr_buffer ppf in + let* () = flush_instr_buffer ~enable_peephole ppf in return () ;; -let gen_program ~enable_gc ppf (analysis : analysis_result) = +let gen_program ~enable_gc ~enable_peephole ppf (analysis : analysis_result) = fprintf ppf ".section .text"; let base = Runtime.Primitives.runtime_primitive_arities in let arity_map = @@ -516,7 +522,7 @@ let gen_program ~enable_gc ppf (analysis : analysis_result) = let* () = modify (fun state -> { state with current_func_index = function_index }) in - gen_func ~enable_gc fn.asm_name fn.params fn.body frame_sz ppf) + gen_func ~enable_gc ~enable_peephole fn.asm_name fn.params fn.body frame_sz ppf) in match run comp init with | Ok ((), _) -> diff --git a/EML/lib/backend/ricsv/generator.mli b/EML/lib/backend/ricsv/generator.mli index 5d1bfb22..b49ab9d9 100644 --- a/EML/lib/backend/ricsv/generator.mli +++ b/EML/lib/backend/ricsv/generator.mli @@ -4,6 +4,7 @@ val gen_program : enable_gc:bool + -> enable_peephole:bool -> Format.formatter -> Analysis.analysis_result -> (unit, string) Result.t diff --git a/EML/lib/backend/ricsv/peephole.ml b/EML/lib/backend/ricsv/peephole.ml new file mode 100644 index 00000000..c9d9e155 --- /dev/null +++ b/EML/lib/backend/ricsv/peephole.ml @@ -0,0 +1,439 @@ +(** Copyright 2025-2026, Victoria Ostrovskaya & Danil Usoltsev *) + +(** SPDX-License-Identifier: LGPL-3.0-or-later *) + +open Architecture +open Riscv_backend + +let is_small_addi_imm imm = imm >= -2048 && imm <= 2047 + +let write_reg = function + | Addi (rd, _, _) + | Ld (rd, _) + | Mv (rd, _) + | Li (rd, _) + | Add (rd, _, _) + | Sub (rd, _, _) + | La (rd, _) + | Slt (rd, _, _) + | Seqz (rd, _) + | Snez (rd, _) + | Xori (rd, _, _) + | Xor (rd, _, _) + | Mul (rd, _, _) + | Div (rd, _, _) + | Slli (rd, _, _) + | Srli (rd, _, _) -> Some rd + | Sd _ | Call _ | Ret | Beq _ | J _ | Label _ -> None +;; + +let reads_reg instruction reg = + let same register = equal_reg register reg in + match instruction with + | Addi (_, rs, _) + | Mv (_, rs) + | Seqz (_, rs) + | Snez (_, rs) + | Xori (_, rs, _) + | Srli (_, rs, _) + | Slli (_, rs, _) -> same rs + | Sd (rs, (base, _)) -> same rs || same base + | Ld (_, (base, _)) -> same base + | Add (_, rs1, rs2) + | Sub (_, rs1, rs2) + | Slt (_, rs1, rs2) + | Xor (_, rs1, rs2) + | Mul (_, rs1, rs2) + | Div (_, rs1, rs2) -> same rs1 || same rs2 + | Beq (rs1, rs2, _) -> same rs1 || same rs2 + | Li _ | Call _ | Ret | J _ | Label _ | La _ -> false +;; + +let reg_used_later reg instructions = + List.exists (fun instruction -> reads_reg instruction reg) instructions +;; + +let is_power_of_two positive_value = + positive_value > 0 && positive_value land (positive_value - 1) = 0 +;; + +let log2_power_of_two n = + let rec loop power value = + if value = 1 then power else loop (power + 1) (value lsr 1) + in + loop 0 n +;; + +let replace_reg from_register to_register instruction = + let replace register = + if equal_reg register from_register then to_register else register + in + match instruction with + | Add (rd, rs1, rs2) -> Add (rd, replace rs1, replace rs2) + | Sub (rd, rs1, rs2) -> Sub (rd, replace rs1, replace rs2) + | Mul (rd, rs1, rs2) -> Mul (rd, replace rs1, replace rs2) + | Div (rd, rs1, rs2) -> Div (rd, replace rs1, replace rs2) + | Xor (rd, rs1, rs2) -> Xor (rd, replace rs1, replace rs2) + | Slt (rd, rs1, rs2) -> Slt (rd, replace rs1, replace rs2) + | Beq (rs1, rs2, label) -> Beq (replace rs1, replace rs2, label) + | Addi (rd, rs, imm) -> Addi (rd, replace rs, imm) + | Xori (rd, rs, imm) -> Xori (rd, replace rs, imm) + | Srli (rd, rs, imm) -> Srli (rd, replace rs, imm) + | Slli (rd, rs, imm) -> Slli (rd, replace rs, imm) + | Seqz (rd, rs) -> Seqz (rd, replace rs) + | Snez (rd, rs) -> Snez (rd, replace rs) + | Sd (rs, (base, offset)) -> Sd (replace rs, (replace base, offset)) + | Ld _ | Mv _ | Li _ | Call _ | Ret | J _ | Label _ | La _ -> instruction +;; + +let simplify_single = function + | Mv (rd, rs) when equal_reg rd rs -> None + | Addi (rd, rs, 0) when equal_reg rd rs -> None + | Addi (rd, rs, 0) -> Some (Mv (rd, rs)) + | Xori (rd, rs, 0) when equal_reg rd rs -> None + | Xori (rd, rs, 0) -> Some (Mv (rd, rs)) + | Add (rd, rs, Zero) when equal_reg rd rs -> None + | Add (rd, Zero, rs) when equal_reg rd rs -> None + | Add (rd, rs, Zero) -> Some (Mv (rd, rs)) + | Add (rd, Zero, rs) -> Some (Mv (rd, rs)) + | Sub (rd, rs, Zero) when equal_reg rd rs -> None + | Sub (rd, rs, Zero) -> Some (Mv (rd, rs)) + | instruction -> Some instruction +;; + +let simplify_pair first second rest = + match first, second with + | Mv (target_register, source_register), next_instruction + when reads_reg next_instruction target_register -> + let safe_to_drop_mv = + match write_reg next_instruction with + | Some written_register when equal_reg written_register target_register -> true + | _ -> not (reg_used_later target_register rest) + in + if safe_to_drop_mv + then Some [ replace_reg target_register source_register next_instruction ] + else None + | ( Li (constant_register, constant_value) + , Add (destination_register, left_register, right_register) ) + when is_small_addi_imm constant_value -> + if equal_reg right_register constant_register + then Some [ Addi (destination_register, left_register, constant_value) ] + else if equal_reg left_register constant_register + then Some [ Addi (destination_register, right_register, constant_value) ] + else None + | ( Li (constant_register, constant_value) + , Mul (destination_register, left_register, right_register) ) + when is_power_of_two constant_value -> + let shift_amount = log2_power_of_two constant_value in + if equal_reg right_register constant_register + then Some [ Slli (destination_register, left_register, shift_amount) ] + else if equal_reg left_register constant_register + then Some [ Slli (destination_register, right_register, shift_amount) ] + else None + | J l1, Label l2 when String.equal l1 l2 -> Some [] + | Sd (_, (base1, offset1)), Sd (rs2, (base2, offset2)) + when equal_reg base1 base2 && offset1 = offset2 -> Some [ Sd (rs2, (base2, offset2)) ] + | Sd (stored_reg, (base1, offset1)), Ld (loaded_reg, (base2, offset2)) + when equal_reg base1 base2 && offset1 = offset2 -> + if equal_reg stored_reg loaded_reg + then Some [ first ] + else Some [ first; Mv (loaded_reg, stored_reg) ] + | Ld (rd1, (base1, offset1)), Ld (rd2, (base2, offset2)) + when equal_reg base1 base2 && offset1 = offset2 -> + if equal_reg rd1 rd2 then Some [ first ] else Some [ first; Mv (rd2, rd1) ] + | Addi (rd1, rs1, imm1), Addi (rd2, rs2, imm2) + when equal_reg rd1 rs1 && equal_reg rd2 rs2 && equal_reg rd1 rd2 -> + let merged = imm1 + imm2 in + if is_small_addi_imm merged + then if merged = 0 then Some [] else Some [ Addi (rd1, rs1, merged) ] + else None + | _ -> + (match write_reg first, write_reg second with + | Some written_first, Some written_second + when equal_reg written_first written_second && not (reads_reg second written_first) + -> Some [ second ] + | _ -> None) +;; + +let simplify_triple first second third rest = + match first, second, third with + | Li (left_register, left_const), Li (right_register, right_const), Beq (rs1, rs2, label) + when equal_reg rs1 left_register && equal_reg rs2 right_register -> + if left_const = right_const + then + Some [ Li (left_register, left_const); Li (right_register, right_const); J label ] + else Some [ Li (left_register, left_const); Li (right_register, right_const) ] + | Mv (target_register, source_register), middle_instruction, next_instruction + when reads_reg next_instruction target_register + && (not (reads_reg middle_instruction target_register)) + && + match write_reg middle_instruction with + | Some written_register -> not (equal_reg written_register target_register) + | None -> true -> + let safe_to_drop_mv = + match write_reg next_instruction with + | Some written_register when equal_reg written_register target_register -> true + | _ -> not (reg_used_later target_register rest) + in + if safe_to_drop_mv + then + Some + [ middle_instruction + ; replace_reg target_register source_register next_instruction + ] + else None + | Mv (first_target, first_source), Mv (second_target, second_source), Add (dst, rs1, rs2) + when equal_reg first_source second_source + && equal_reg rs1 first_target + && equal_reg rs2 second_target -> Some [ Add (dst, first_source, first_source) ] + | Mv (first_target, first_source), Mv (second_target, second_source), Mul (dst, rs1, rs2) + when equal_reg first_source second_source + && equal_reg rs1 first_target + && equal_reg rs2 second_target -> Some [ Mul (dst, first_source, first_source) ] + | Mv (first_target, first_source), Mv (second_target, second_source), Sub (dst, rs1, rs2) + when equal_reg first_source second_source + && equal_reg rs1 first_target + && equal_reg rs2 second_target -> Some [ Sub (dst, first_source, first_source) ] + | Mv (first_target, first_source), Mv (second_target, second_source), Div (dst, rs1, rs2) + when equal_reg first_source second_source + && equal_reg rs1 first_target + && equal_reg rs2 second_target -> Some [ Div (dst, first_source, first_source) ] + | Mv (first_target, first_source), Mv (second_target, second_source), Xor (dst, rs1, rs2) + when equal_reg first_source second_source + && equal_reg rs1 first_target + && equal_reg rs2 second_target -> Some [ Xor (dst, first_source, first_source) ] + | Mv (first_target, first_source), Mv (second_target, second_source), Slt (dst, rs1, rs2) + when equal_reg first_source second_source + && equal_reg rs1 first_target + && equal_reg rs2 second_target -> Some [ Slt (dst, first_source, first_source) ] + | _ -> None +;; + +let one_pass instructions = + let rec loop changed acc = function + | first :: second :: third :: rest -> + (match simplify_triple first second third rest with + | Some rewritten -> + let rewritten_reversed = List.rev_append rewritten acc in + loop true rewritten_reversed rest + | None -> + (match simplify_pair first second (third :: rest) with + | Some rewritten -> + let rewritten_reversed = List.rev_append rewritten acc in + loop true rewritten_reversed (third :: rest) + | None -> + (match simplify_single first with + | None -> loop true acc (second :: third :: rest) + | Some simplified -> + loop changed (simplified :: acc) (second :: third :: rest)))) + | first :: second :: rest -> + (match simplify_pair first second rest with + | Some rewritten -> + let rewritten_reversed = List.rev_append rewritten acc in + loop true rewritten_reversed rest + | None -> + (match simplify_single first with + | None -> loop true acc (second :: rest) + | Some simplified -> loop changed (simplified :: acc) (second :: rest))) + | [ last ] -> + (match simplify_single last with + | None -> List.rev acc, true + | Some simplified -> List.rev (simplified :: acc), changed) + | [] -> List.rev acc, changed + in + loop false [] instructions +;; + +let same_memory_key (base1, offset1) (base2, offset2) = + equal_reg base1 base2 && offset1 = offset2 +;; + +let find_cached_load key cache = + List.find_map + (fun (cached_key, cached_register) -> + if same_memory_key cached_key key then Some cached_register else None) + cache +;; + +let remove_cached_key key cache = + List.filter (fun (cached_key, _) -> not (same_memory_key cached_key key)) cache +;; + +let invalidate_register register cache = + List.filter + (fun ((base, _), cached_register) -> + not (equal_reg cached_register register || equal_reg base register)) + cache +;; + +let track_load_cache instructions = + let rec loop changed cache acc = function + | [] -> List.rev acc, changed + | instruction :: rest -> + (match instruction with + | Ld (destination_register, key) -> + (match find_cached_load key cache with + | Some cached_register when equal_reg cached_register destination_register -> + loop true cache acc rest + | Some cached_register -> + let cache_without_destination = + invalidate_register destination_register cache + in + let next_cache = (key, destination_register) :: cache_without_destination in + loop true next_cache (Mv (destination_register, cached_register) :: acc) rest + | None -> + let cache_without_destination = + invalidate_register destination_register cache + in + let next_cache = (key, destination_register) :: cache_without_destination in + loop changed next_cache (instruction :: acc) rest) + | Sd (stored_register, key) -> + let next_cache = + remove_cached_key key cache + |> fun cache_without_key -> (key, stored_register) :: cache_without_key + in + loop changed next_cache (instruction :: acc) rest + | Call _ | Ret | Beq _ | J _ | Label _ -> loop changed [] (instruction :: acc) rest + | _ -> + let next_cache = + match write_reg instruction with + | Some written_register -> invalidate_register written_register cache + | None -> cache + in + loop changed next_cache (instruction :: acc) rest) + in + loop false [] [] instructions +;; + +let reads_slot (slot_base, slot_offset) = function + | Ld (_, (base, offset)) -> equal_reg base slot_base && offset = slot_offset + | _ -> false +;; + +let stores_slot (slot_base, slot_offset) = function + | Sd (_, (base, offset)) -> equal_reg base slot_base && offset = slot_offset + | _ -> false +;; + +let writes_slot_base (slot_base, _) instruction = + match write_reg instruction with + | Some written_register -> equal_reg written_register slot_base + | None -> false +;; + +let can_prove_store_dead ~allow_drop_at_block_end slot following_instructions = + let rec walk = function + | [] -> allow_drop_at_block_end + | instruction :: rest -> + (match + ( reads_slot slot instruction + , stores_slot slot instruction + , writes_slot_base slot instruction ) + with + | true, _, _ -> false + | _, true, _ -> true + | _, _, true -> false + | _ -> walk rest) + in + walk following_instructions +;; + +let eliminate_dead_stores_in_block ~allow_drop_at_block_end block = + let rec loop changed acc = function + | [] -> List.rev acc, changed + | Sd (_, slot) :: rest when can_prove_store_dead ~allow_drop_at_block_end slot rest -> + loop true acc rest + | (Sd (_, _) as store_instruction) :: rest -> + loop changed (store_instruction :: acc) rest + | instruction :: rest -> loop changed (instruction :: acc) rest + in + loop false [] block +;; + +let eliminate_local_dead_stores instructions = + let is_barrier = function + | Call _ | Ret | Beq _ | J _ | Label _ -> true + | _ -> false + in + let rec split_non_barrier acc = function + | instruction :: rest when not (is_barrier instruction) -> + split_non_barrier (instruction :: acc) rest + | remaining -> List.rev acc, remaining + in + let rec process changed acc = function + | [] -> List.rev acc, changed + | instructions -> + let block, remaining = split_non_barrier [] instructions in + let allow_drop_at_block_end = + match remaining with + | Ret :: _ -> true + | _ -> false + in + let optimized_block, block_changed = + eliminate_dead_stores_in_block ~allow_drop_at_block_end block + in + (match remaining with + | barrier :: tail -> + process + (changed || block_changed) + (barrier :: List.rev_append optimized_block acc) + tail + | [] -> List.rev (List.rev_append optimized_block acc), changed || block_changed) + in + process false [] instructions +;; + +let find_redundant_restore_store loaded_register loaded_slot following_instructions = + let is_barrier = function + | Call _ | Ret | Beq _ | J _ | Label _ -> true + | _ -> false + in + let rec search prefix = function + | [] -> List.rev_append prefix [], false + | instruction :: rest + when is_barrier instruction || writes_slot_base loaded_slot instruction -> + List.rev_append prefix (instruction :: rest), false + | Sd (stored_register, store_slot) :: rest + when same_memory_key loaded_slot store_slot + && equal_reg stored_register loaded_register -> + List.rev_append prefix rest, true + | Sd (stored_register, store_slot) :: rest when same_memory_key loaded_slot store_slot + -> List.rev_append prefix (Sd (stored_register, store_slot) :: rest), false + | instruction :: rest -> + (match write_reg instruction with + | Some written_register when equal_reg written_register loaded_register -> + List.rev_append prefix (instruction :: rest), false + | _ -> search (instruction :: prefix) rest) + in + search [] following_instructions +;; + +let eliminate_redundant_restore_stores instructions = + let rec loop changed acc = function + | (Ld (loaded_register, loaded_slot) as load_instruction) :: rest -> + let new_rest, removed = + find_redundant_restore_store loaded_register loaded_slot rest + in + loop (changed || removed) (load_instruction :: acc) new_rest + | instruction :: rest -> loop changed (instruction :: acc) rest + | [] -> List.rev acc, changed + in + loop false [] instructions +;; + +let optimize instructions = + let rec fixed_point current = + let after_local, changed_local = one_pass current in + let after_load_cache, changed_cache = track_load_cache after_local in + let after_redundant_store, changed_redundant_store = + eliminate_redundant_restore_stores after_load_cache + in + let after_dead_store, changed_dead_store = + eliminate_local_dead_stores after_redundant_store + in + if changed_local || changed_cache || changed_redundant_store || changed_dead_store + then fixed_point after_dead_store + else after_dead_store + in + fixed_point instructions +;; diff --git a/EML/lib/backend/ricsv/peephole.mli b/EML/lib/backend/ricsv/peephole.mli new file mode 100644 index 00000000..11751d43 --- /dev/null +++ b/EML/lib/backend/ricsv/peephole.mli @@ -0,0 +1,8 @@ +(** Copyright 2025-2026, Victoria Ostrovskaya & Danil Usoltsev *) + +(** SPDX-License-Identifier: LGPL-3.0-or-later *) + +open Architecture +open Riscv_backend + +val optimize : instr list -> instr list diff --git a/EML/lib/backend/ricsv/runner.ml b/EML/lib/backend/ricsv/runner.ml index 9efd4272..a6cb52b3 100644 --- a/EML/lib/backend/ricsv/runner.ml +++ b/EML/lib/backend/ricsv/runner.ml @@ -5,6 +5,7 @@ open Middleend.Anf open Analysis -let gen_program ?(enable_gc = false) ppf (program : anf_program) = - program |> analyze |> Generator.gen_program ~enable_gc ppf +let gen_program ?(enable_gc = false) ?(enable_peephole = true) ppf (program : anf_program) + = + program |> analyze |> Generator.gen_program ~enable_gc ~enable_peephole ppf ;; diff --git a/EML/lib/backend/ricsv/runner.mli b/EML/lib/backend/ricsv/runner.mli index e5df8f59..c6ef9f26 100644 --- a/EML/lib/backend/ricsv/runner.mli +++ b/EML/lib/backend/ricsv/runner.mli @@ -4,6 +4,7 @@ val gen_program : ?enable_gc:bool + -> ?enable_peephole:bool -> Format.formatter -> Middleend.Anf.anf_program -> (unit, string) result diff --git a/EML/lib/frontend/ast.ml b/EML/lib/frontend/ast.ml index 18b2e913..0ca8309e 100644 --- a/EML/lib/frontend/ast.ml +++ b/EML/lib/frontend/ast.ml @@ -24,6 +24,7 @@ type bin_oper = | LowerThan (* [<] *) | Equal (* [=] *) | NotEqual (* [<>] *) + | Custom of string (* user-defined: ( ** ), ( @@ ), etc. *) [@@deriving show { with_path = false }] type unar_oper = @@ -89,6 +90,39 @@ type structure = type program = structure list [@@deriving show { with_path = false }] let bin_op_list = [ "*"; "/"; "+"; "-"; "^"; ">="; "<="; "<>"; "="; ">"; "<"; "&&"; "||" ] +let builtin_op_list = [ "*"; "/"; "+"; "-"; ">="; "<="; "<>"; "="; ">"; "<"; "&&"; "||" ] + +let builtin_op_of_string = function + | "*" -> Some Multiply + | "/" -> Some Division + | "+" -> Some Plus + | "-" -> Some Minus + | ">=" -> Some GreatestEqual + | "<=" -> Some LowestEqual + | "<>" -> Some NotEqual + | "=" -> Some Equal + | ">" -> Some GreaterThan + | "<" -> Some LowerThan + | "&&" -> Some And + | "||" -> Some Or + | _ -> None +;; + +let builtin_op_to_string = function + | Multiply -> "*" + | Division -> "/" + | Plus -> "+" + | Minus -> "-" + | GreatestEqual -> ">=" + | LowestEqual -> "<=" + | NotEqual -> "<>" + | Equal -> "=" + | GreaterThan -> ">" + | LowerThan -> "<" + | And -> "&&" + | Or -> "||" + | Custom s -> s +;; let rec pp_ty fmt = function | TyPrim x -> fprintf fmt "%s" x diff --git a/EML/lib/frontend/parser.ml b/EML/lib/frontend/parser.ml index 3f871894..73dcd33c 100644 --- a/EML/lib/frontend/parser.ml +++ b/EML/lib/frontend/parser.ml @@ -40,6 +40,86 @@ let is_digit = function | _ -> false ;; +let is_operator_char = function + | '!' + | '$' + | '%' + | '&' + | '*' + | '+' + | '-' + | '.' + | '/' + | ':' + | '<' + | '=' + | '>' + | '?' + | '@' + | '^' + | '|' + | '~' -> true + | _ -> false +;; + +let is_operator_char_infix = function + | '!' + | '$' + | '%' + | '&' + | '*' + | '+' + | '-' + | '.' + | '/' + | '<' + | '=' + | '>' + | '?' + | '@' + | '^' + | '~' -> true + | '|' | ':' | _ -> false +;; + +let is_custom_power_op op = + String.length op >= 2 && String.equal (String.sub op ~pos:0 ~len:2) "**" +;; + +let first_char op = String.get op 0 + +let is_custom_mul_op op = + (not (is_custom_power_op op)) + && + match first_char op with + | '*' | '/' | '%' -> true + | _ -> false +;; + +let is_custom_add_op op = + match first_char op with + | '+' | '-' -> true + | _ -> false +;; + +let is_custom_concat_op op = + match first_char op with + | '@' | '^' -> true + | _ -> false +;; + +let is_custom_cmp_op op = + match first_char op with + | '=' | '<' | '>' | '|' | '&' | '$' -> true + | _ -> false +;; + +let is_custom_lowest_op op = + match first_char op with + | '!' | '?' | '~' | '.' -> true + | _ -> false +;; + let white_space = take_while Char.is_whitespace let token s = white_space *> string s let token1 s = white_space *> s @@ -120,10 +200,7 @@ let parse_ident = >>= fun s -> if is_keyword s then fail "It is not identifier" else return s in let parse_op_ident = - white_space - *> char '(' - *> white_space - *> choice (List.map Ast.bin_op_list ~f:(fun opr -> token opr *> return opr)) + white_space *> char '(' *> white_space *> take_while1 is_operator_char <* white_space <* char ')' in @@ -255,8 +332,8 @@ let parse_pattern = parse_pattern_tuple lst <|> lst) ;; -let parse_left_associative expr oper = - let rec go acc = lift2 (fun f x -> f acc x) oper expr >>= go <|> return acc in +let parse_left_associative expr oper right_expr = + let rec go acc = lift2 (fun f x -> f acc x) oper right_expr >>= go <|> return acc in expr >>= go ;; @@ -264,6 +341,15 @@ let parse_expr_bin_oper parse_bin_op tkn = token tkn *> return (fun e1 e2 -> ExpBinOper (parse_bin_op, e1, e2)) ;; +let parse_right_associative expr oper = + let rec parse () = + expr + >>= fun left -> + oper >>= (fun combine -> parse () >>| fun right -> combine left right) <|> return left + in + parse () +;; + let multiply = parse_expr_bin_oper Multiply "*" let division = parse_expr_bin_oper Division "/" let plus = parse_expr_bin_oper Plus "+" @@ -282,6 +368,17 @@ let compare = let and_op = parse_expr_bin_oper And "&&" let or_op = parse_expr_bin_oper Or "||" + +let parse_custom_infix_when pred = + white_space *> take_while1 is_operator_char_infix + >>= fun op -> + if Option.is_some (builtin_op_of_string op) + then fail "builtin" + else if pred op + then return (fun e1 e2 -> ExpBinOper (Custom op, e1, e2)) + else fail "custom_op_mismatch" +;; + let parse_expr_ident = parse_ident >>| fun x -> ExpIdent x let parse_expr_const = parse_const >>| fun c -> ExpConst c @@ -309,8 +406,8 @@ let parse_expr_list parse_expr = (fun (fst_exp, snd_exp, exp_list) -> ExpTuple (fst_exp, snd_exp, exp_list)) ;; -let parse_expr_apply e = - parse_left_associative e (return (fun e1 e2 -> ExpApply (e1, e2))) +let parse_expr_apply e right = + parse_left_associative e (return (fun e1 e2 -> ExpApply (e1, e2))) right ;; let parse_expr_lambda parse_expr = @@ -370,6 +467,7 @@ let parse_expr_sequence parse_expr = parse_left_associative parse_expr (token ";" *> return (fun exp1 exp2 -> ExpLet (NonRec, (PatUnit, exp1), [], exp2))) + parse_expr ;; let parse_expr_construct parse_expr = @@ -480,13 +578,35 @@ let parse_top_expr parse_expr = ] ;; -let parse_exp_apply e = - let app = parse_expr_apply e in +let parse_exp_apply e right = + let app = parse_expr_apply e right in let app = parse_expr_unar_oper app <|> app in - let ops1 = parse_left_associative app (multiply <|> division) in - let ops2 = parse_left_associative ops1 (plus <|> minus) in - let cmp = parse_left_associative ops2 compare in - parse_left_associative cmp (and_op <|> or_op) + let power = parse_right_associative app (parse_custom_infix_when is_custom_power_op) in + let ops1 = + parse_left_associative + power + (parse_custom_infix_when is_custom_mul_op <|> multiply <|> division) + power + in + let ops2 = + parse_left_associative + ops1 + (parse_custom_infix_when is_custom_add_op <|> plus <|> minus) + ops1 + in + let concat = + parse_right_associative ops2 (parse_custom_infix_when is_custom_concat_op) + in + let cmp = + parse_left_associative + concat + (parse_custom_infix_when is_custom_cmp_op + <|> parse_custom_infix_when is_custom_lowest_op + <|> compare) + concat + in + let bool_and = parse_right_associative cmp and_op in + parse_right_associative bool_and or_op ;; let parse_expr = @@ -502,7 +622,7 @@ let parse_expr = ; parse_parens expr ] in - let func = parse_exp_apply term <|> term in + let func = parse_exp_apply term term <|> term in let lst = parse_expr_list func <|> func in let tuple = parse_expr_tuple lst <|> lst in let seq = parse_expr_sequence tuple <|> tuple in diff --git a/EML/lib/middleend/anf.ml b/EML/lib/middleend/anf.ml index 54a5eca3..d23dcba2 100644 --- a/EML/lib/middleend/anf.ml +++ b/EML/lib/middleend/anf.ml @@ -141,6 +141,10 @@ let rec anf (expr : expr) (k : immediate -> anf_expr t) : anf_expr t = | ExpIdent x -> k (ImmediateVar x) | ExpUnarOper (op, expr) -> anf expr (fun imm -> bind_complex_expr (ComplexUnarOper (op, imm)) k) + | ExpBinOper (Custom op_name, exp1, exp2) -> + anf exp1 (fun imm1 -> + anf exp2 (fun imm2 -> + bind_complex_expr (ComplexApp (ImmediateVar op_name, imm1, [ imm2 ])) k)) | ExpBinOper (op, exp1, exp2) -> anf exp1 (fun imm1 -> anf exp2 (fun imm2 -> bind_complex_expr (ComplexBinOper (op, imm1, imm2)) k)) diff --git a/EML/lib/middleend/inferencer.ml b/EML/lib/middleend/inferencer.ml index ebe7b073..3cda300a 100644 --- a/EML/lib/middleend/inferencer.ml +++ b/EML/lib/middleend/inferencer.ml @@ -4,9 +4,6 @@ (* Template: https://gitlab.com/Kakadu/fp2020course-materials/-/tree/master/code/miniml?ref_type=heads*) -(* Inference state is purely immutable: no Hashtbl, no [ref] or [mutable]. We use - [Map] (tree-like) for [var_levels] and thread state through the monad. *) - open Base open Frontend.Ast open Stdlib.Format @@ -287,6 +284,7 @@ module TypeEnv = struct let apply subst env = Map.map env ~f:(Scheme.apply subst) let find = Map.find + let keys = Map.keys let initial_env = let open Base.Map in @@ -459,6 +457,33 @@ let infer_binop_type = function fresh_var >>| fun fresh_ty -> fresh_ty, fresh_ty, TyPrim "bool" | Plus | Minus | Multiply | Division -> return (TyPrim "int", TyPrim "int", TyPrim "int") | And | Or -> return (TyPrim "bool", TyPrim "bool", TyPrim "bool") + | Custom _ -> fail (NoVariable "infer_binop_type: Custom handled in infer_expr") +;; + +(* Returns (arg_ty1, arg_ty2, res_ty, subst_op). For Custom the caller must ensure op_name is in env. *) +let get_binop_arg_res env op = + match op with + | Custom op_name -> + let* op_scheme = + match TypeEnv.find env op_name with + | Some s -> return s + | None -> fail (NoVariable op_name) + in + let* op_ty = instantiate op_scheme in + let* arg_ty1 = fresh_var in + let* arg_ty2 = fresh_var in + let* res_ty = fresh_var in + let* subst = + Substitution.unify op_ty (TyArrow (arg_ty1, TyArrow (arg_ty2, res_ty))) + in + return + ( Substitution.apply subst arg_ty1 + , Substitution.apply subst arg_ty2 + , Substitution.apply subst res_ty + , subst ) + | _ -> + let* ty1, ty2, ty_res = infer_binop_type op in + return (ty1, ty2, ty_res, Substitution.empty) ;; let rec infer_expr env = function @@ -488,11 +513,26 @@ let rec infer_expr env = function | ExpBinOper (op, expr1, expr2) -> let* subst1, ty = infer_expr env expr1 in let* subst2, ty' = infer_expr (TypeEnv.apply subst1 env) expr2 in - let* ty1_op, ty2_op, ty_res = infer_binop_type op in - let* subst3 = Substitution.unify (Substitution.apply subst2 ty) ty1_op in - let* subst4 = Substitution.unify (Substitution.apply subst3 ty') ty2_op in - let* subst = Substitution.compose_all [ subst1; subst2; subst3; subst4 ] in - return (subst, Substitution.apply subst ty_res) + let* arg_ty1, arg_ty2, res_ty, subst_op = + match op with + | Custom op_name when Option.is_none (TypeEnv.find env op_name) -> + (match builtin_op_of_string op_name with + | Some builtin_op -> get_binop_arg_res env builtin_op + | None -> fail (NoVariable op_name)) + | _ -> get_binop_arg_res env op + in + let* subst3 = + Substitution.unify + (Substitution.apply subst2 ty) + (Substitution.apply subst_op arg_ty1) + in + let* subst4 = + Substitution.unify + (Substitution.apply subst3 ty') + (Substitution.apply subst3 arg_ty2) + in + let* subst = Substitution.compose_all [ subst1; subst2; subst_op; subst3; subst4 ] in + return (subst, Substitution.apply subst res_ty) | ExpBranch (cond, then_expr, else_expr) -> let* subst1, ty1 = infer_expr env cond in let* subst2, ty2 = infer_expr (TypeEnv.apply subst1 env) then_expr in @@ -550,14 +590,14 @@ let rec infer_expr env = function (match tys with | [] -> fail (SeveralBounds "inferred empty list type") | ty :: _ -> return (total_subst, TyList ty))) - | ExpLet (NonRec, (PatVariable x, expr1), _, expr2) -> + | ExpLet (NonRec, (PatVariable x, expr1), [], expr2) -> let* () = enter_level in let* subst1, ty1 = infer_expr env expr1 in let* () = leave_level in let env2 = TypeEnv.apply subst1 env in let* ty_gen = generalize env2 ty1 in - let env3 = TypeEnv.extend env x ty_gen in - let* subst2, ty2 = infer_expr (TypeEnv.apply subst1 env3) expr2 in + let env3 = TypeEnv.extend env2 x ty_gen in + let* subst2, ty2 = infer_expr env3 expr2 in let* total_subst = Substitution.compose subst1 subst2 in return (total_subst, ty2) | ExpLet (NonRec, (pattern, expr1), bindings, expr2) -> diff --git a/EML/lib/middleend/inferencer.mli b/EML/lib/middleend/inferencer.mli index f501e0bf..81159a90 100644 --- a/EML/lib/middleend/inferencer.mli +++ b/EML/lib/middleend/inferencer.mli @@ -55,6 +55,7 @@ module TypeEnv : sig val free_vars : t -> VarSet.t val apply : Substitution.t -> t -> t val find : t -> ident -> Scheme.t option + val keys : t -> ident list val initial_env : t val env_with_gc : t end diff --git a/EML/lib/middleend/resolve_builtins.ml b/EML/lib/middleend/resolve_builtins.ml new file mode 100644 index 00000000..8cc860e5 --- /dev/null +++ b/EML/lib/middleend/resolve_builtins.ml @@ -0,0 +1,114 @@ +(** Copyright 2025-2026, Victoria Ostrovskaya & Danil Usoltsev *) + +(** SPDX-License-Identifier: LGPL-3.0-or-later *) + +open Base +open Frontend.Ast + +let names_of_pattern p = + let rec go = function + | PatVariable x -> [ x ] + | PatAny | PatConst _ | PatUnit -> [] + | PatType (q, _) -> go q + | PatTuple (a, b, rest) -> go a @ go b @ List.concat_map rest ~f:go + | PatList ps -> List.concat_map ps ~f:go + | PatOption None -> [] + | PatOption (Some q) -> go q + | PatConstruct (_, None) -> [] + | PatConstruct (_, Some q) -> go q + in + go p +;; + +let names_of_bind (pat, _) = names_of_pattern pat + +let rec resolve_expr scope = function + | ExpBinOper (Custom op, e1, e2) -> + let e1' = resolve_expr scope e1 in + let e2' = resolve_expr scope e2 in + (match List.mem scope op ~equal:String.equal, builtin_op_of_string op with + | true, _ -> ExpBinOper (Custom op, e1', e2') + | false, Some b -> ExpBinOper (b, e1', e2') + | false, None -> ExpBinOper (Custom op, e1', e2')) + | ExpIdent x -> ExpIdent x + | ExpConst c -> ExpConst c + | ExpBranch (c, t, o) -> + ExpBranch + (resolve_expr scope c, resolve_expr scope t, Option.map o ~f:(resolve_expr scope)) + | ExpUnarOper (u, e') -> ExpUnarOper (u, resolve_expr scope e') + | ExpTuple (a, b, rest) -> + ExpTuple + (resolve_expr scope a, resolve_expr scope b, List.map rest ~f:(resolve_expr scope)) + | ExpList es -> ExpList (List.map es ~f:(resolve_expr scope)) + | ExpLambda (pat, pats, body) -> + let scope' = + scope @ names_of_pattern pat @ List.concat_map pats ~f:names_of_pattern + in + ExpLambda (pat, pats, resolve_expr scope' body) + | ExpTypeAnnotation (e', ty) -> ExpTypeAnnotation (resolve_expr scope e', ty) + | ExpLet (rec_flag, (pat, e1), binds, body) -> + let names = + names_of_pattern pat @ List.concat_map binds ~f:(fun (p, _) -> names_of_pattern p) + in + let scope' = scope @ names in + let scope_rhs = + match rec_flag with + | Rec -> scope' + | NonRec -> scope + in + ExpLet + ( rec_flag + , (pat, resolve_expr scope_rhs e1) + , List.map binds ~f:(fun (p, e') -> p, resolve_expr scope_rhs e') + , resolve_expr scope' body ) + | ExpApply (f, a) -> ExpApply (resolve_expr scope f, resolve_expr scope a) + | ExpOption None -> ExpOption None + | ExpOption (Some e') -> ExpOption (Some (resolve_expr scope e')) + | ExpFunction (c, cases) -> + let names = names_of_bind c @ List.concat_map cases ~f:names_of_bind in + let scope' = scope @ names in + ExpFunction + ( (fst c, resolve_expr scope' (snd c)) + , List.map cases ~f:(fun (p, e') -> p, resolve_expr scope' e') ) + | ExpMatch (scrut, c, cases) -> + let names = names_of_bind c @ List.concat_map cases ~f:names_of_bind in + let scope' = scope @ names in + ExpMatch + ( resolve_expr scope scrut + , (fst c, resolve_expr scope' (snd c)) + , List.map cases ~f:(fun (p, e') -> p, resolve_expr scope' e') ) + | ExpConstruct (c, o) -> ExpConstruct (c, Option.map o ~f:(resolve_expr scope)) + | ExpBinOper (b, e1, e2) -> + let left_resolved = resolve_expr scope e1 in + let right_resolved = resolve_expr scope e2 in + let builtin_op_name = builtin_op_to_string b in + if List.mem scope builtin_op_name ~equal:String.equal + then ExpBinOper (Custom builtin_op_name, left_resolved, right_resolved) + else ExpBinOper (b, left_resolved, right_resolved) +;; + +let resolve_structure scope = function + | SEval e -> SEval (resolve_expr scope e), scope + | SValue (rec_flag, (pat, e1), binds) -> + let names = + names_of_pattern pat @ List.concat_map binds ~f:(fun (p, _) -> names_of_pattern p) + in + let scope' = scope @ names in + let scope_rhs = + match rec_flag with + | Rec -> scope' + | NonRec -> scope + in + let e1' = resolve_expr scope_rhs e1 in + let binds' = List.map binds ~f:(fun (p, e') -> p, resolve_expr scope_rhs e') in + SValue (rec_flag, (pat, e1'), binds'), scope' +;; + +let resolve_program (program : program) (initial_scope : string list) : program = + let _, rev_resolved = + List.fold_left program ~init:(initial_scope, []) ~f:(fun (scope, acc) s -> + let s', scope' = resolve_structure scope s in + scope', s' :: acc) + in + List.rev rev_resolved +;; diff --git a/EML/lib/middleend/resolve_builtins.mli b/EML/lib/middleend/resolve_builtins.mli new file mode 100644 index 00000000..8dc462d5 --- /dev/null +++ b/EML/lib/middleend/resolve_builtins.mli @@ -0,0 +1,5 @@ +(** Copyright 2025-2026, Victoria Ostrovskaya & Danil Usoltsev *) + +(** SPDX-License-Identifier: LGPL-3.0-or-later *) + +val resolve_program : Frontend.Ast.program -> string list -> Frontend.Ast.program diff --git a/EML/lib/middleend/runner.ml b/EML/lib/middleend/runner.ml index a093518e..bad7b088 100644 --- a/EML/lib/middleend/runner.ml +++ b/EML/lib/middleend/runner.ml @@ -5,6 +5,7 @@ open Format open Frontend.Ast open Inferencer +open Resolve_builtins open Cc open Ll open Anf @@ -34,7 +35,9 @@ let run (program : program) (env : Inferencer.TypeEnv.t) in env' >>= fun env'' -> - closure_conversion_result program + let initial_scope = Inferencer.TypeEnv.keys env in + let program' = resolve_program program initial_scope in + closure_conversion_result program' |> Result.map_error (fun e -> Closure e) >>= fun cc_ast -> lambda_lifting_result cc_ast diff --git a/EML/lib/utils/pretty_printer.ml b/EML/lib/utils/pretty_printer.ml index 1c46b19a..e02c37a3 100644 --- a/EML/lib/utils/pretty_printer.ml +++ b/EML/lib/utils/pretty_printer.ml @@ -17,6 +17,7 @@ let string_of_bin_op = function | LowerThan -> "<" | Equal -> "=" | NotEqual -> "<>" + | Custom s -> s ;; let string_of_unary_op = function diff --git a/EML/out.ll b/EML/out.ll deleted file mode 100644 index 715a3d8d..00000000 --- a/EML/out.ll +++ /dev/null @@ -1,11 +0,0 @@ -; ModuleID = 'main' -source_filename = "main" -target triple = "x86_64-pc-linux-gnu" - -declare void @print_int(i64) - -define i64 @main() { -entry: - call void @print_int(i64 70) - ret i64 0 -} diff --git a/EML/tests/additional_tests/custom_op_left_associativity.ml b/EML/tests/additional_tests/custom_op_left_associativity.ml new file mode 100644 index 00000000..7f2c23dd --- /dev/null +++ b/EML/tests/additional_tests/custom_op_left_associativity.ml @@ -0,0 +1,6 @@ +let ( =^.^= ) x y = x - y + +let main = + let () = print_int (3 =^.^= 4 =^.^= 5) in + 0 +;; diff --git a/EML/tests/additional_tests/custom_op_pipe.ml b/EML/tests/additional_tests/custom_op_pipe.ml new file mode 100644 index 00000000..b309e49d --- /dev/null +++ b/EML/tests/additional_tests/custom_op_pipe.ml @@ -0,0 +1,7 @@ +let ( ~> ) x f = f x +let succ x = x + 1 + +let main = + let () = print_int (10 ~>succ) in + 0 +;; diff --git a/EML/tests/additional_tests/custom_op_right_associativity.ml b/EML/tests/additional_tests/custom_op_right_associativity.ml new file mode 100644 index 00000000..c5388f27 --- /dev/null +++ b/EML/tests/additional_tests/custom_op_right_associativity.ml @@ -0,0 +1,6 @@ +let ( ** ) x y = x - y + +let main = + let () = print_int (10 ** 3 ** 2) in + 0 +;; diff --git a/EML/tests/additional_tests/custom_op_shadowing.ml b/EML/tests/additional_tests/custom_op_shadowing.ml new file mode 100644 index 00000000..faab0a50 --- /dev/null +++ b/EML/tests/additional_tests/custom_op_shadowing.ml @@ -0,0 +1,5 @@ +let main = + let ( * ) x y = x + y in + let () = print_int (2 * 7) in + 0 +;; diff --git a/EML/tests/additional_tests/custom_op_via_op.ml b/EML/tests/additional_tests/custom_op_via_op.ml new file mode 100644 index 00000000..db25a282 --- /dev/null +++ b/EML/tests/additional_tests/custom_op_via_op.ml @@ -0,0 +1,7 @@ +let ( ** ) x y = x * y +let ( +++ ) x y = (x ** y) + 1 + +let main = + let () = print_int (3 +++ 4) in + 0 +;; diff --git a/EML/tests/anf_tests.ml b/EML/tests/anf_tests.ml index f0caba99..30354196 100644 --- a/EML/tests/anf_tests.ml +++ b/EML/tests/anf_tests.ml @@ -7,11 +7,15 @@ open EML_lib.Middleend.Anf open EML_lib.Middleend.Anf_pp open EML_lib.Middleend.Runner open EML_lib.Middleend.Inferencer +open EML_lib.Middleend.Resolve_builtins + +let initial_scope = TypeEnv.keys TypeEnv.initial_env let parse_and_anf input = match parse input with | Ok ast -> - (match anf_program ast with + let ast' = resolve_program ast initial_scope in + (match anf_program ast' with | Ok anf_ast -> Printf.printf "%s\n" (show_anf_program anf_ast) | Error e -> Printf.printf "ANF error: %s\n" e) | Error e -> Printf.printf "Parsing error: %s\n" e @@ -20,7 +24,8 @@ let parse_and_anf input = let parse_and_anf_pp input = match parse input with | Ok ast -> - (match anf_program ast with + let ast' = resolve_program ast initial_scope in + (match anf_program ast' with | Ok anf_ast -> Printf.printf "%s\n" (anf_to_string anf_ast) | Error e -> Printf.printf "ANF error: %s\n" e) | Error e -> Printf.printf "Parsing error: %s\n" e @@ -299,3 +304,26 @@ let%expect_test "anf_roundtrip_types_partial" = | Error e -> Printf.printf "FAIL: %s\n" e); [%expect {| OK: types preserved after ANF round-trip |}] ;; + +let%expect_test "custom_infix_operator_lowers_to_app" = + parse_and_anf "let ( =^.^= ) x y = (x * 10) + y"; + [%expect + {| +[(AnfValue (NonRec, + ("=^.^=", 2, + (AnfExpr + (ComplexLambda ([(PatVariable "x")], + (AnfExpr + (ComplexLambda ([(PatVariable "y")], + (AnfLet (NonRec, "anf_t0", + (ComplexBinOper (Multiply, (ImmediateVar "x"), + (ImmediateConst (ConstInt 10)))), + (AnfExpr + (ComplexBinOper (Plus, (ImmediateVar "anf_t0"), + (ImmediateVar "y")))) + )) + ))) + )))), + [])) + ]|}] +;; diff --git a/EML/tests/dune b/EML/tests/dune index 7bf49fbb..290c94be 100644 --- a/EML/tests/dune +++ b/EML/tests/dune @@ -13,6 +13,7 @@ (deps (file ../bin/EML.exe) (file Makefile) + (file additional_tests/mangling_test.ml) (file ../lib/runtime/rv64_runtime.a) (source_tree additional_tests) (source_tree gc_tests) @@ -24,6 +25,7 @@ (file ../bin/EML.exe) (file Makefile) (file ../lib/runtime/rv64_runtime.a) + (source_tree additional_tests) (source_tree gc_tests) (source_tree many_tests))) @@ -45,8 +47,6 @@ (source_tree gc_tests) (source_tree many_tests))) -;; LLVM tests require clang to be installed (e.g. apt-get install clang). - (cram (applies_to llvm) (deps diff --git a/EML/tests/infer.t b/EML/tests/infer.t index 08e22066..ea9f1bfb 100644 --- a/EML/tests/infer.t +++ b/EML/tests/infer.t @@ -114,3 +114,23 @@ SPDX-License-Identifier: LGPL-3.0-or-later $ make infer many_tests/do_not_type/099.ml 2>&1 | sed -n '1p' Inferencer error: Left-hand side error: Only variables are allowed on the left-hand side of let rec. + $ make infer additional_tests/custom_op_via_op.ml + val **: int -> int -> int + val +++: int -> int -> int + val main: int + + $ make infer additional_tests/custom_op_left_associativity.ml + val =^.^=: int -> int -> int + val main: int + + $ make infer additional_tests/custom_op_right_associativity.ml + val **: int -> int -> int + val main: int + + $ make infer additional_tests/custom_op_shadowing.ml + val main: int + + $ make infer additional_tests/custom_op_pipe.ml + val main: int + val succ: int -> int + val ~>: t0 -> (t0 -> t2) -> t2 diff --git a/EML/tests/inferencer_tests.ml b/EML/tests/inferencer_tests.ml index 78369081..92a1bb3d 100644 --- a/EML/tests/inferencer_tests.ml +++ b/EML/tests/inferencer_tests.ml @@ -440,3 +440,47 @@ let%expect_test "test_ast_pattern_unit_lambda" = pretty_printer_infer_simple_expression (ExpLambda (PatUnit, [], ExpConst (ConstInt 1))); [%expect {|unit -> int|}] ;; + +let%expect_test "custom_infix_operator" = + pretty_printer_parse_and_infer + {| let ( ** ) x y = x * y +let main = 2 ** 3 |}; + [%expect + {| + val **: int -> int -> int + val main: int|}] +;; + +let%expect_test "custom_infix_bind_like" = + pretty_printer_parse_and_infer + {| let ( >>= ) n _ = if n <= 1 then 1 else n * (n - 1) +let main = 3 >>= 0 |}; + [%expect + {| + val >>=: int -> t1 -> int + val main: int|}] +;; + +let%expect_test "custom_infix_power" = + pretty_printer_parse_and_infer + {| let rec ( ^^ ) x n = if n <= 0 then 1 else x * (x ^^ (n - 1)) +let main = 2 ^^ 10 |}; + [%expect + {| + val ^^: int -> int -> int + val main: int|}] +;; + +let%expect_test "custom_infix_compose" = + pretty_printer_parse_and_infer + {| let ( @@ ) f g = fun x -> f (g x) +let succ x = x + 1 +let double x = x * 2 +let main = (succ @@ double) 10 |}; + [%expect + {| + val @@: (t3 -> t4) -> (t2 -> t3) -> t2 -> t4 + val double: int -> int + val main: int + val succ: int -> int|}] +;; diff --git a/EML/tests/llvm.t b/EML/tests/llvm.t index 6255a5a2..4d081c0a 100644 --- a/EML/tests/llvm.t +++ b/EML/tests/llvm.t @@ -78,3 +78,18 @@ SPDX-License-Identifier: LGPL-3.0-or-later $ make compile_llvm additional_tests/mangling_test.ml 24 + + $ make compile_llvm additional_tests/custom_op_left_associativity.ml + -6 + + $ make compile_llvm additional_tests/custom_op_right_associativity.ml + 9 + + $ make compile_llvm additional_tests/custom_op_shadowing.ml + 9 + + $ make compile_llvm additional_tests/custom_op_via_op.ml + 13 + + $ make compile_llvm additional_tests/custom_op_pipe.ml + 11 diff --git a/EML/tests/llvm_tests.ml b/EML/tests/llvm_tests.ml index f88257c7..f25762fc 100644 --- a/EML/tests/llvm_tests.ml +++ b/EML/tests/llvm_tests.ml @@ -5,12 +5,17 @@ open EML_lib open Frontend.Parser open Middleend.Anf +open Middleend.Inferencer +open Middleend.Resolve_builtins + +let initial_scope = TypeEnv.keys TypeEnv.initial_env let compile_llvm src : string = match parse src with | Error e -> "Parse error: " ^ e | Ok ast -> - (match anf_program ast with + let ast' = resolve_program ast initial_scope in + (match anf_program ast' with | Error e -> "ANF error: " ^ e | Ok anf -> let buf = Buffer.create 4096 in @@ -983,3 +988,124 @@ let%expect_test "codegen closure fn with 10 arg" = attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) } |}] ;; + +let%expect_test "custom op cat" = + compile_llvm_show {|let ( =^.^= ) x y = x - y|}; + [%expect + {| + ; ModuleID = 'EML' + source_filename = "EML" + + declare ptr @eml_applyN(ptr, i64, ptr) + + declare ptr @create_tuple(i64, ptr) + + declare ptr @alloc_closure(ptr, i64) + + declare ptr @field(ptr, i64) + + declare ptr @llvm_call_indirect(ptr, ptr, i64) + + declare void @print_int(i64) + + declare void @init_gc() + + declare void @destroy_gc() + + declare void @set_ptr_stack(ptr) + + declare i64 @get_heap_start() + + declare i64 @get_heap_final() + + declare ptr @collect() + + declare ptr @print_gc_status() + + ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(none) + declare ptr @llvm.frameaddress.p0(i32 immarg) #0 + + define ptr @op__eq_hat_dot_hat_eq(ptr %x, ptr %y) { + entry: + %raw_int = ptrtoint ptr %x to i64 + %minus1 = sub i64 %raw_int, 1 + %untagged = sdiv i64 %minus1, 2 + %raw_int1 = ptrtoint ptr %y to i64 + %minus12 = sub i64 %raw_int1, 1 + %untagged3 = sdiv i64 %minus12, 2 + %sub = sub i64 %untagged, %untagged3 + %twice = mul i64 %sub, 2 + %tagged = add i64 %twice, 1 + %result_int = inttoptr i64 %tagged to ptr + ret ptr %result_int + } + + define ptr @main() { + entry: + ret ptr inttoptr (i64 1 to ptr) + } + + attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) } + |}] +;; + +let%expect_test "custom op pipe" = + compile_llvm_show {|let ( ~> ) x f = f x|}; + [%expect + {| + ; ModuleID = 'EML' + source_filename = "EML" + + declare ptr @eml_applyN(ptr, i64, ptr) + + declare ptr @create_tuple(i64, ptr) + + declare ptr @alloc_closure(ptr, i64) + + declare ptr @field(ptr, i64) + + declare ptr @llvm_call_indirect(ptr, ptr, i64) + + declare void @print_int(i64) + + declare void @init_gc() + + declare void @destroy_gc() + + declare void @set_ptr_stack(ptr) + + declare i64 @get_heap_start() + + declare i64 @get_heap_final() + + declare ptr @collect() + + declare ptr @print_gc_status() + + ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(none) + declare ptr @llvm.frameaddress.p0(i32 immarg) #0 + + define ptr @op__tilde_gt(ptr %x, ptr %f) { + entry: + br label %apply_step_0 + + merge_0: ; preds = %apply_step_0 + %apply_result = phi ptr [ %apply_step_01, %apply_step_0 ] + ret ptr %apply_result + + apply_step_0: ; preds = %entry + %apply_one = alloca [1 x ptr], align 8 + %one_elem = getelementptr [1 x ptr], ptr %apply_one, i32 0, i32 0 + store ptr %x, ptr %one_elem, align 8 + %apply_step_01 = call ptr @eml_applyN(ptr %f, i64 1, ptr %one_elem) + br label %merge_0 + } + + define ptr @main() { + entry: + ret ptr inttoptr (i64 1 to ptr) + } + + attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) } + |}] +;; diff --git a/EML/tests/parser_tests.ml b/EML/tests/parser_tests.ml index 5af3cc27..c922b30d 100644 --- a/EML/tests/parser_tests.ml +++ b/EML/tests/parser_tests.ml @@ -231,3 +231,39 @@ let%expect_test "test_unit" = ] |}] ;; + +let%expect_test "custom_infix_operator" = + parse_test {| let ( ** ) x y = x * y in 2 ** 3 |}; + [%expect + {| +[(SEval + (ExpLet (NonRec, + ((PatVariable "**"), + (ExpLambda ((PatVariable "x"), [(PatVariable "y")], + (ExpBinOper (Multiply, (ExpIdent "x"), (ExpIdent "y")))))), + [], + (ExpBinOper ((Custom "**"), (ExpConst (ConstInt 2)), + (ExpConst (ConstInt 3)))) + ))) + ] + |}] +;; + +let%expect_test "custom_power_operator_is_right_associative" = + parse_test {| let ( ** ) x y = x * y in 2 ** 3 ** 4 |}; + [%expect + {| +[(SEval + (ExpLet (NonRec, + ((PatVariable "**"), + (ExpLambda ((PatVariable "x"), [(PatVariable "y")], + (ExpBinOper (Multiply, (ExpIdent "x"), (ExpIdent "y")))))), + [], + (ExpBinOper ((Custom "**"), (ExpConst (ConstInt 2)), + (ExpBinOper ((Custom "**"), (ExpConst (ConstInt 3)), + (ExpConst (ConstInt 4)))) + )) + ))) + ] +|}] +;; diff --git a/EML/tests/riscv.t b/EML/tests/riscv.t index 228b833d..b2cc6a06 100644 --- a/EML/tests/riscv.t +++ b/EML/tests/riscv.t @@ -79,3 +79,18 @@ SPDX-License-Identifier: LGPL-3.0-or-later $ make compile_riscv additional_tests/mangling_test.ml 24 + + $ make compile_riscv additional_tests/custom_op_left_associativity.ml + -6 + + $ make compile_riscv additional_tests/custom_op_right_associativity.ml + 9 + + $ make compile_riscv additional_tests/custom_op_shadowing.ml + 9 + + $ make compile_riscv additional_tests/custom_op_via_op.ml + 13 + + $ make compile_riscv additional_tests/custom_op_pipe.ml + 11 diff --git a/EML/tests/riscv_peephole_tests.ml b/EML/tests/riscv_peephole_tests.ml new file mode 100644 index 00000000..44e8b2b8 --- /dev/null +++ b/EML/tests/riscv_peephole_tests.ml @@ -0,0 +1,630 @@ +(** Copyright 2025-2026, Victoria Ostrovskaya & Danil Usoltsev *) + +(** SPDX-License-Identifier: LGPL-3.0-or-later *) + +open EML_lib +open Frontend.Parser +open Middleend.Anf +open Middleend.Inferencer +open Middleend.Resolve_builtins +open Backend.Ricsv.Architecture +open Riscv_backend + +let print_instrs instructions = + let rendered = + List.map (fun instruction -> Format.asprintf "%a" pp_instr instruction) instructions + in + print_endline (String.concat "\n" rendered) +;; + +let compile_riscv ?(enable_peephole = true) src = + match parse src with + | Error e -> "Parse error: " ^ e + | Ok ast -> + let scope = TypeEnv.keys TypeEnv.initial_env in + let ast' = resolve_program ast scope in + (match anf_program ast' with + | Error e -> "ANF error: " ^ e + | Ok anf -> + let buf = Buffer.create 1024 in + let ppf = Format.formatter_of_buffer buf in + (match Backend.Ricsv.Runner.gen_program ~enable_peephole ppf anf with + | Ok () -> + Format.pp_print_flush ppf (); + Buffer.contents buf + | Error e -> "Codegen error: " ^ e)) +;; + +let show_diff ~input ~output value = + print_endline "=== Without peepholes ==="; + input value; + print_endline ""; + print_endline "=== With peepholes ==="; + output value +;; + +let show_codogen_diff src = + show_diff + ~input:(fun source -> print_endline (compile_riscv ~enable_peephole:false source)) + ~output:(fun source -> print_endline (compile_riscv ~enable_peephole:true source)) + src +;; + +let show_instr_diff instrs = + show_diff + ~input:print_instrs + ~output:(fun instructions -> + instructions |> Backend.Ricsv.Peephole.optimize |> print_instrs) + instrs +;; + +let%expect_test "optimizes repeated stack load pattern from task description" = + let input = + [ Li (T 0, 1) + ; Ld (T 1, (SP, 64)) + ; Add (T 0, T 1, T 0) + ; Sd (T 1, (SP, 64)) + ; Li (T 0, 2) + ; Ld (T 1, (SP, 64)) + ; Mul (T 0, T 1, T 0) + ; Sd (T 1, (SP, 64)) + ] + in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +li t0, 1 +ld t1, 64(sp) +add t0, t1, t0 +sd t1, 64(sp) +li t0, 2 +ld t1, 64(sp) +mul t0, t1, t0 +sd t1, 64(sp) + +=== With peepholes === +li t0, 1 +ld t1, 64(sp) +slli t0, t1, 1 +|}] +;; + +let%expect_test "removes redundant load and forwards store to load" = + let input = + [ Ld (T 0, (SP, 64)); Ld (T 1, (SP, 64)); Sd (T 1, (SP, 64)); Ld (A 0, (SP, 64)) ] + in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +ld t0, 64(sp) +ld t1, 64(sp) +sd t1, 64(sp) +ld a0, 64(sp) + +=== With peepholes === +ld t0, 64(sp) +mv t1, t0 +sd t1, 64(sp) +mv a0, t1 +|}] +;; + +let%expect_test "folds addi chain and removes dead overwrite" = + let input = + [ Addi (SP, SP, -16); Addi (SP, SP, 8); Li (T 0, 1); Li (T 0, 2); Addi (T 1, T 1, 0) ] + in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +addi sp, sp, -16 +addi sp, sp, 8 +li t0, 1 +li t0, 2 +addi t1, t1, 0 + +=== With peepholes === +addi sp, sp, -8 +li t0, 2 +|}] +;; + +let%expect_test "drops jump to the immediately following label" = + let input = [ J "l1"; Label "l1"; Mv (A 0, A 0); Ret ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +j l1 +l1: +mv a0, a0 +ret + +=== With peepholes === +ret +|}] +;; + +let%expect_test "collapses double copy before binary op" = + let input = [ Mv (T 0, A 0); Mv (T 1, A 0); Add (A 0, T 0, T 1) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +mv t0, a0 +mv t1, a0 +add a0, t0, t1 + +=== With peepholes === +add a0, a0, a0 +|}] +;; + +let%expect_test "propagates single mv into following consumer" = + let input = [ Mv (T 0, A 0); Li (T 1, 1); Slt (A 0, T 0, T 1) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +mv t0, a0 +li t1, 1 +slt a0, t0, t1 + +=== With peepholes === +li t1, 1 +slt a0, a0, t1 +|}] +;; + +let%expect_test "rewrites li plus add into addi" = + let input = [ Li (T 1, 1); Add (A 0, T 0, T 1) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +li t1, 1 +add a0, t0, t1 + +=== With peepholes === +addi a0, t0, 1 +|}] +;; + +let%expect_test "folds li plus add when destination is constant register" = + let input = [ Li (T 0, 1); Add (T 0, T 1, T 0) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +li t0, 1 +add t0, t1, t0 + +=== With peepholes === +addi t0, t1, 1 +|}] +;; + +let%expect_test "rewrites mul by power of two into slli" = + let input = [ Li (T 0, 4); Mul (A 0, T 1, T 0) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +li t0, 4 +mul a0, t1, t0 + +=== With peepholes === +slli a0, t1, 2 +|}] +;; + +let%expect_test "keeps load cache barriers on call" = + let input = + [ Ld (T 0, (SP, 64)); Call "foo"; Ld (T 1, (SP, 64)); Add (A 0, T 0, T 1) ] + in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +ld t0, 64(sp) +call foo +ld t1, 64(sp) +add a0, t0, t1 + +=== With peepholes === +ld t0, 64(sp) +call foo +ld t1, 64(sp) +add a0, t0, t1 +|}] +;; + +let%expect_test "forwards store to following load on same slot" = + let input = [ Sd (A 0, (fp, -16)); Ld (T 0, (fp, -16)); Add (A 0, T 0, A 1) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +sd a0, -16(fp) +ld t0, -16(fp) +add a0, t0, a1 + +=== With peepholes === +sd a0, -16(fp) +add a0, a0, a1 +|}] +;; + +let%expect_test "folds constant beq into jump" = + let input = [ Li (T 0, 1); Li (T 1, 1); Beq (T 0, T 1, "else_1") ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +li t0, 1 +li t1, 1 +beq t0, t1, else_1 + +=== With peepholes === +li t0, 1 +li t1, 1 +j else_1 +|}] +;; + +let%expect_test "removes dead store before ret in same block" = + let input = [ Sd (A 0, (fp, -8)); Add (A 0, A 0, A 1); Ret ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +sd a0, -8(fp) +add a0, a0, a1 +ret + +=== With peepholes === +add a0, a0, a1 +ret +|}] +;; + +let%expect_test "keeps store before call barrier" = + let input = [ Sd (A 0, (fp, -8)); Call "foo"; Ret ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +sd a0, -8(fp) +call foo +ret + +=== With peepholes === +sd a0, -8(fp) +call foo +ret +|}] +;; + +let%expect_test "removes store that restores unchanged loaded slot value" = + let input = [ Ld (T 1, (sp, 64)); Add (T 0, T 1, T 0); Sd (T 1, (sp, 64)) ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +ld t1, 64(sp) +add t0, t1, t0 +sd t1, 64(sp) + +=== With peepholes === +ld t1, 64(sp) +add t0, t1, t0 +|}] +;; + +let%expect_test "drops overwritten store before next store to same slot" = + let input = [ Sd (A 0, (fp, -8)); Add (A 0, A 0, A 1); Sd (T 0, (fp, -8)); Ret ] in + show_instr_diff input; + [%expect + {| +=== Without peepholes === +sd a0, -8(fp) +add a0, a0, a1 +sd t0, -8(fp) +ret + +=== With peepholes === +add a0, a0, a1 +ret +|}] +;; + +let%expect_test "shows code with and without peephole 1" = + let src = + {| + let f x = + let y = x < 0 in + y + 1 + let main = f 1 + |} + in + show_codogen_diff src; + [%expect + {| + === Without peepholes === + .section .text + .globl f + .type f, @function + f: + addi sp, sp, -24 + sd ra, 16(sp) + sd fp, 8(sp) + addi fp, sp, 8 + sd a0, -8(fp) + ld t0, -8(fp) + li t1, 1 + slt a0, t0, t1 + add a0, a0, a0 + addi a0, a0, 1 + sd a0, -16(fp) + ld t0, -16(fp) + li t1, 3 + add a0, t0, t1 + addi a0, a0, -1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + li a0, 3 + call f + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret + + + === With peepholes === + .section .text + .globl f + .type f, @function + f: + addi sp, sp, -24 + sd ra, 16(sp) + sd fp, 8(sp) + addi fp, sp, 8 + sd a0, -8(fp) + li t1, 1 + slt a0, a0, t1 + add a0, a0, a0 + addi a0, a0, 1 + sd a0, -16(fp) + addi a0, a0, 2 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + li a0, 3 + call f + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret |}] +;; + +let%expect_test "shows code with and without peephole 2" = + let src = + {| + let f y = + let x = 1 + y in + let z = 2 * y in + x + z + let main = f 10 + |} + in + show_codogen_diff src; + [%expect + {| + === Without peepholes === + .section .text + .globl f + .type f, @function + f: + addi sp, sp, -32 + sd ra, 24(sp) + sd fp, 16(sp) + addi fp, sp, 16 + sd a0, -8(fp) + li t0, 3 + ld t1, -8(fp) + add a0, t0, t1 + addi a0, a0, -1 + sd a0, -16(fp) + li t0, 5 + ld t1, -8(fp) + srli t0, t0, 1 + addi t1, t1, -1 + mul a0, t0, t1 + addi a0, a0, 1 + sd a0, -24(fp) + ld t0, -16(fp) + ld t1, -24(fp) + add a0, t0, t1 + addi a0, a0, -1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + li a0, 21 + call f + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret + + + === With peepholes === + .section .text + .globl f + .type f, @function + f: + addi sp, sp, -32 + sd ra, 24(sp) + sd fp, 16(sp) + addi fp, sp, 16 + sd a0, -8(fp) + li t0, 3 + mv t1, a0 + add a0, t0, t1 + addi a0, a0, -1 + sd a0, -16(fp) + li t0, 5 + srli t0, t0, 1 + addi t1, t1, -1 + mul a0, t0, t1 + addi a0, a0, 1 + sd a0, -24(fp) + ld t0, -16(fp) + add a0, t0, a0 + addi a0, a0, -1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + li a0, 21 + call f + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret |}] +;; + +let%expect_test "shows code with and without peephole 3" = + let src = + {| + let g x = + let a = x + 1 in + let b = a + 1 in + b + 1 + let main = g 1 + |} + in + show_codogen_diff src; + [%expect + {| + === Without peepholes === + .section .text + .globl g + .type g, @function + g: + addi sp, sp, -32 + sd ra, 24(sp) + sd fp, 16(sp) + addi fp, sp, 16 + sd a0, -8(fp) + ld t0, -8(fp) + li t1, 3 + add a0, t0, t1 + addi a0, a0, -1 + sd a0, -16(fp) + ld t0, -16(fp) + li t1, 3 + add a0, t0, t1 + addi a0, a0, -1 + sd a0, -24(fp) + ld t0, -24(fp) + li t1, 3 + add a0, t0, t1 + addi a0, a0, -1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + li a0, 3 + call g + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret + + + === With peepholes === + .section .text + .globl g + .type g, @function + g: + addi sp, sp, -32 + sd ra, 24(sp) + sd fp, 16(sp) + addi fp, sp, 16 + sd a0, -8(fp) + addi a0, a0, 2 + sd a0, -16(fp) + addi a0, a0, 2 + sd a0, -24(fp) + addi a0, a0, 2 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + li a0, 3 + call g + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret |}] +;; diff --git a/EML/tests/riscv_peephole_tests.mli b/EML/tests/riscv_peephole_tests.mli new file mode 100644 index 00000000..ff96571d --- /dev/null +++ b/EML/tests/riscv_peephole_tests.mli @@ -0,0 +1,3 @@ +(** Copyright 2025-2026, Victoria Ostrovskaya & Danil Usoltsev *) + +(** SPDX-License-Identifier: LGPL-3.0-or-later *) diff --git a/EML/tests/riscv_tests.ml b/EML/tests/riscv_tests.ml index 279e862c..efe0d4e3 100644 --- a/EML/tests/riscv_tests.ml +++ b/EML/tests/riscv_tests.ml @@ -5,12 +5,16 @@ open EML_lib open Frontend.Parser open Middleend.Anf +open Middleend.Inferencer +open Middleend.Resolve_builtins let compile src : string = match parse src with | Error e -> "Parse error: " ^ e | Ok ast -> - (match anf_program ast with + let scope = TypeEnv.keys TypeEnv.initial_env in + let ast' = resolve_program ast scope in + (match anf_program ast' with | Error e -> "ANF error: " ^ e | Ok anf -> let buf = Buffer.create 1024 in @@ -35,7 +39,7 @@ x: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp li t0, 11 li a0, 1 sub a0, a0, t0 @@ -50,7 +54,7 @@ main: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp li a0, 1 addi sp, fp, 16 ld ra, 8(fp) @@ -71,7 +75,7 @@ x: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp li t0, 3 xori a0, t0, 3 addi sp, fp, 16 @@ -85,7 +89,7 @@ main: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp li a0, 1 addi sp, fp, 16 ld ra, 8(fp) @@ -106,7 +110,7 @@ main: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp li a0, 1 addi sp, fp, 16 ld ra, 8(fp) @@ -127,7 +131,7 @@ main: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp li t0, 15 li t1, 17 srli t0, t0, 1 @@ -157,11 +161,9 @@ double: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp sd a0, -8(fp) - ld t0, -8(fp) - ld t1, -8(fp) - add a0, t0, t1 + add a0, a0, a0 addi a0, a0, -1 addi sp, fp, 16 ld ra, 8(fp) @@ -202,13 +204,13 @@ abs: sd fp, 16(sp) addi fp, sp, 16 sd a0, -8(fp) - ld t0, -8(fp) + mv t0, a0 li t1, 1 slt a0, t0, t1 add a0, a0, a0 addi a0, a0, 1 sd a0, -16(fp) - ld t0, -16(fp) + mv t0, a0 li t1, 1 beq t0, t1, else_0 ld t0, -8(fp) @@ -256,12 +258,10 @@ sq: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp sd a0, -8(fp) - ld t0, -8(fp) - ld t1, -8(fp) - srli t0, t0, 1 - addi t1, t1, -1 + srli t0, a0, 1 + addi t1, a0, -1 mul a0, t0, t1 addi a0, a0, 1 addi sp, fp, 16 @@ -278,15 +278,13 @@ sum_of_squares: addi fp, sp, 384 sd a0, -8(fp) sd a1, -16(fp) - ld a0, -8(fp) call sq sd a0, -24(fp) ld a0, -16(fp) call sq sd a0, -32(fp) ld t0, -24(fp) - ld t1, -32(fp) - add a0, t0, t1 + add a0, t0, a0 addi a0, a0, -1 addi sp, fp, 16 ld ra, 8(fp) @@ -328,13 +326,13 @@ fib: sd fp, 416(sp) addi fp, sp, 416 sd a0, -8(fp) - ld t0, -8(fp) + mv t0, a0 li t1, 5 slt a0, t0, t1 add a0, a0, a0 addi a0, a0, 1 sd a0, -16(fp) - ld t0, -16(fp) + mv t0, a0 li t1, 1 beq t0, t1, else_0 li a0, 3 @@ -345,7 +343,6 @@ else_0: sub a0, t0, t1 addi a0, a0, 1 sd a0, -24(fp) - ld a0, -24(fp) call fib sd a0, -32(fp) ld t0, -8(fp) @@ -353,12 +350,10 @@ else_0: sub a0, t0, t1 addi a0, a0, 1 sd a0, -40(fp) - ld a0, -40(fp) call fib sd a0, -48(fp) ld t0, -32(fp) - ld t1, -48(fp) - add a0, t0, t1 + add a0, t0, a0 addi a0, a0, -1 end_0: addi sp, fp, 16 @@ -398,11 +393,10 @@ is_positive: addi sp, sp, -16 sd ra, 8(sp) sd fp, 0(sp) - addi fp, sp, 0 + mv fp, sp sd a0, -8(fp) - ld t0, -8(fp) li t1, 1 - slt a0, t1, t0 + slt a0, t1, a0 add a0, a0, a0 addi a0, a0, 1 addi sp, fp, 16 @@ -446,17 +440,13 @@ mul3: sd a0, -8(fp) sd a1, -16(fp) sd a2, -24(fp) - ld t0, -8(fp) - ld t1, -16(fp) - srli t0, t0, 1 - addi t1, t1, -1 + srli t0, a0, 1 + addi t1, a1, -1 mul a0, t0, t1 addi a0, a0, 1 sd a0, -32(fp) - ld t0, -32(fp) - ld t1, -24(fp) - srli t0, t0, 1 - addi t1, t1, -1 + srli t0, a0, 1 + addi t1, a2, -1 mul a0, t0, t1 addi a0, a0, 1 addi sp, fp, 16 @@ -506,15 +496,14 @@ let%expect_test "test1" = addi fp, sp, 384 sd a0, -8(fp) li t0, 1 - ld t1, -8(fp) + mv t1, a0 xor a0, t0, t1 snez a0, a0 add a0, a0, a0 addi a0, a0, 1 sd a0, -16(fp) - ld t0, -16(fp) li t1, 1 - beq t0, t1, else_0 + beq a0, t1, else_0 li a0, 1 call print_int j end_0 @@ -536,7 +525,7 @@ let%expect_test "test1" = addi fp, sp, 424 li t0, 1 li t1, 1 - beq t0, t1, else_1 + j else_1 li a0, 1 j end_1 else_1: @@ -546,25 +535,22 @@ let%expect_test "test1" = li a0, 3 end_1: sd a0, -16(fp) - ld t0, -16(fp) li t1, 1 - beq t0, t1, else_2 + beq a0, t1, else_2 li a0, 1 j end_2 else_2: li a0, 3 end_2: sd a0, -24(fp) - ld t0, -24(fp) li t1, 1 - beq t0, t1, else_3 + beq a0, t1, else_3 li a0, 1 j end_3 else_3: li a0, 3 end_3: sd a0, -32(fp) - ld a0, -32(fp) call large addi sp, fp, 16 ld ra, 8(fp) @@ -603,34 +589,22 @@ let%expect_test "codegen closure fn with 10 arg" = sd a4, -40(fp) sd a5, -48(fp) sd a6, -56(fp) - ld t0, -8(fp) - ld t1, -16(fp) - add a0, t0, t1 + add a0, a0, a1 addi a0, a0, -1 sd a0, -64(fp) - ld t0, -64(fp) - ld t1, -24(fp) - add a0, t0, t1 + add a0, a0, a2 addi a0, a0, -1 sd a0, -72(fp) - ld t0, -72(fp) - ld t1, -32(fp) - add a0, t0, t1 + add a0, a0, a3 addi a0, a0, -1 sd a0, -80(fp) - ld t0, -80(fp) - ld t1, -40(fp) - add a0, t0, t1 + add a0, a0, a4 addi a0, a0, -1 sd a0, -88(fp) - ld t0, -88(fp) - ld t1, -48(fp) - add a0, t0, t1 + add a0, a0, a5 addi a0, a0, -1 sd a0, -96(fp) - ld t0, -96(fp) - ld t1, -56(fp) - add a0, t0, t1 + add a0, a0, a6 addi a0, a0, -1 addi sp, fp, 16 ld ra, 8(fp) @@ -661,7 +635,6 @@ let%expect_test "codegen closure fn with 10 arg" = call eml_applyN addi sp, sp, 32 sd a0, -8(fp) - ld a0, -8(fp) li a1, 2 addi sp, sp, -16 li t0, 3 @@ -672,7 +645,6 @@ let%expect_test "codegen closure fn with 10 arg" = call eml_applyN addi sp, sp, 16 sd a0, -16(fp) - ld a0, -16(fp) li a1, 2 addi sp, sp, -16 li t0, 3 @@ -683,7 +655,6 @@ let%expect_test "codegen closure fn with 10 arg" = call eml_applyN addi sp, sp, 16 sd a0, -24(fp) - ld a0, -24(fp) call print_int addi sp, fp, 16 ld ra, 8(fp) @@ -692,3 +663,82 @@ let%expect_test "codegen closure fn with 10 arg" = ret |}] ;; + +let%expect_test "custom op cat" = + run {|let ( =^.^= ) x y = x - y|}; + [%expect + {| + .section .text + .globl op__eq_hat_dot_hat_eq + .type op__eq_hat_dot_hat_eq, @function + op__eq_hat_dot_hat_eq: + addi sp, sp, -16 + sd ra, 8(sp) + sd fp, 0(sp) + mv fp, sp + sd a0, -8(fp) + sd a1, -16(fp) + sub a0, a0, a1 + addi a0, a0, 1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -16 + sd ra, 8(sp) + sd fp, 0(sp) + mv fp, sp + li a0, 1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret + |}] +;; + +let%expect_test "custom op pipe" = + run {|let ( ~> ) x f = f x|}; + [%expect + {| + .section .text + .globl op__tilde_gt + .type op__tilde_gt, @function + op__tilde_gt: + addi sp, sp, -200 + sd ra, 192(sp) + sd fp, 184(sp) + addi fp, sp, 184 + sd a0, -8(fp) + sd a1, -16(fp) + mv a0, a1 + li a1, 1 + addi sp, sp, -8 + ld t0, -8(fp) + sd t0, 0(sp) + mv a2, sp + call eml_applyN + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + ret + + .globl main + .type main, @function + main: + addi sp, sp, -16 + sd ra, 8(sp) + sd fp, 0(sp) + mv fp, sp + li a0, 1 + addi sp, fp, 16 + ld ra, 8(fp) + ld fp, 0(fp) + li a0, 0 + ret + |}] +;;