diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 14bbd07eca29..7a7dcd5a577a 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -324,6 +324,8 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } + + f = WithAttr(std::move(f), "horizontal_fuse", Bool(true)); GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false); return IRModule(Map({{global_var, f}})); } @@ -360,6 +362,10 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_ if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } + + + f = WithAttr(std::move(f), "horizontal_fuse", Bool(true)); + IRModule mod = IRModule(Map({{GlobalVar(name), f}})); // Get the pass list