diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index 82ddb28de4337..896499e0d64da 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -823,6 +823,20 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
     return true;
   };
 
+  auto setter_for_xla_cpu_opt_preset =
+      [debug_options](absl::string_view input) {
+        std::string upper_input = absl::AsciiStrToUpper(input);
+        if (!absl::StartsWith(upper_input, "CPU_OPT_PRESET_")) {
+          upper_input = absl::StrCat("CPU_OPT_PRESET_", upper_input);
+        }
+        DebugOptions::CpuOptPreset preset;
+        if (!DebugOptions::CpuOptPreset_Parse(upper_input, &preset)) {
+          return false;
+        }
+        debug_options->set_xla_cpu_opt_preset(preset);
+        return true;
+      };
+
   auto setter_for_xla_cpu_enable_concurrency_optimized_scheduler =
       [debug_options](bool value) {
         if (value) {
@@ -1219,6 +1233,10 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "xla_cpu_use_acl", bool_setter_for(&DebugOptions::set_xla_cpu_use_acl),
       debug_options->xla_cpu_use_acl(),
       "Generate calls to ACL (Arm Compute Library) in the CPU backend."));
+  flag_list->push_back(tsl::Flag(
+      "xla_cpu_opt_preset", setter_for_xla_cpu_opt_preset,
+      DebugOptions::CpuOptPreset_Name(debug_options->xla_cpu_opt_preset()),
+      "Set CPU optimization preset (FAST_RUNTIME, FAST_COMPILE)"));
   flag_list->push_back(
       tsl::Flag("xla_cpu_use_fusion_emitters",
                 bool_setter_for(&DebugOptions::set_xla_cpu_use_fusion_emitters),
diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc
index 6ecca8af0e4eb..2760e1003d284 100644
--- a/xla/service/cpu/cpu_compiler.cc
+++ b/xla/service/cpu/cpu_compiler.cc
@@ -580,8 +580,12 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   const int64_t num_partitions = module->config().num_partitions();
   const bool use_fusion_emitters =
       module->config().debug_options().xla_cpu_use_fusion_emitters();
-  bool use_shardy_partitioner = module->config().use_shardy_partitioner();
-  bool flatten_before_fusion = !options::FlattenAfterFusion(module->config());
+  const bool use_shardy_partitioner = module->config().use_shardy_partitioner();
+  const bool fast_compile =
+      module->config().debug_options().xla_cpu_opt_preset() ==
+      xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE;
+  const bool flatten_before_fusion =
+      !options::FlattenAfterFusion(module->config()) && !fast_compile;
 
   // Replace asynchronous collectives with synchronous ones.
   HloPassPipeline async_collective_pipeline("async-collective");
@@ -873,6 +877,9 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   }
 
 #ifdef XLA_ONEDNN
+  // We must re-obtain the reference to debug_options because the previous
+  // reference is invalidated by the pipeline passes that run before this
+  // lambda.
   const DebugOptions& debug_options = module->config().debug_options();
   if ((debug_options.xla_cpu_use_onednn() ||
        debug_options.xla_cpu_experimental_onednn_custom_call()) &&
@@ -964,6 +971,10 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
   const auto& debug_options = module->config().debug_options();
   const bool use_fusion_emitters = debug_options.xla_cpu_use_fusion_emitters();
   bool flatten_after_fusion = options::FlattenAfterFusion(module->config());
+  if (debug_options.xla_cpu_opt_preset() ==
+      xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE) {
+    flatten_after_fusion = true;
+  }
 
   HloPassPipeline pipeline("HLO passes after layout assignment");
   {
@@ -2101,21 +2112,35 @@ absl::StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
         options.cpu_target_config->cpu_target_machine_options.value();
   }
 
+  const bool fast_compile =
+      module->config().debug_options().xla_cpu_opt_preset() ==
+      xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE;
+
+  llvm::CodeGenOptLevel opt_level =
+      IrCompiler::GetCodeGenOptLevel(module->config());
+  if (fast_compile &&
+      !module->config().debug_options().has_xla_backend_optimization_level()) {
+    opt_level = llvm::CodeGenOptLevel::Less;
+  }
+
+  const bool disable_expensive_passes =
+      module->config().debug_options().xla_llvm_disable_expensive_passes();
+
   // Options for compiling LLVM IR to machine code.
   IrCompiler::Options ir_compiler_options{
-      /*optimization_level=*/IrCompiler::GetCodeGenOptLevel(module->config()),
-      /*optimize_for_size=*/options::OptimizeForSizeRequested(module->config()),
+      /*optimization_level=*/opt_level,
+      /*optimize_for_size=*/
+      options::OptimizeForSizeRequested(module->config()),
       /*target_machine_options=*/
       target_machine_options,
       /*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(module->config()),
-      /*disable_expensive_passes=*/
-      module->config().debug_options().xla_llvm_disable_expensive_passes(),
-      /*slp_vectorizer_disabled=*/
+      /*disable_expensive_passes=*/disable_expensive_passes,
+      /*disable_slp_vectorizer=*/
       options::SlpVectorizerDisabled(module->config()),
       /*disable_loop_unrolling=*/
       options::DisableLoopUnrolling(module->config()),
       /*disable_platform_dependent_math=*/
-      options::DisablePlatformDependentMath(module->config()),
+      options::DisablePlatformDependentMath(module->config()) || fast_compile,
   };
 
   ThunkEmitter::Options thunk_emitter_options = {
@@ -2244,20 +2269,32 @@ CpuCompiler::CompileAheadOfTimeThunks(
       triple.normalize(), target_machine->getTargetCPU(),
       target_machine->getTargetFeatureString());
 
+  const bool fast_compile =
+      module->config().debug_options().xla_cpu_opt_preset() ==
+      xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE;
+
+  llvm::CodeGenOptLevel opt_level = target_machine->getOptLevel();
+  if (fast_compile &&
+      !module->config().debug_options().has_xla_backend_optimization_level()) {
+    opt_level = llvm::CodeGenOptLevel::Less;
+  }
+
+  const bool disable_expensive_passes =
+      module->config().debug_options().xla_llvm_disable_expensive_passes();
+
   IrCompiler::Options ir_compiler_options = {
-      /*optimization_level=*/target_machine->getOptLevel(),
+      /*optimization_level=*/opt_level,
       /*optimize_for_size=*/
       options::OptimizeForSizeRequested(module->config()),
       /*target_machine_options=*/target_machine_options,
       /*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(module->config()),
-      /*disable_expensive_passes=*/
-      module->config().debug_options().xla_llvm_disable_expensive_passes(),
+      /*disable_expensive_passes=*/disable_expensive_passes,
       /*disable_slp_vectorizer=*/
       options::SlpVectorizerDisabled(module->config()),
       /*disable_loop_unrolling=*/
       options::DisableLoopUnrolling(module->config()),
       /*disable_platform_dependent_math=*/
-      options::DisablePlatformDependentMath(module->config()),
+      options::DisablePlatformDependentMath(module->config()) || fast_compile,
       /*dfsan_enabled=*/aot_options.sanitize_dataflow(),
       /*dfsan_abilists_enabled=*/aot_options.sanitize_abilists_dataflow()};
 
diff --git a/xla/xla.proto b/xla/xla.proto
index 29d4b80b67255..e9f270843e433 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -180,6 +180,13 @@ message DebugOptions {
   //--------------------------------------------------------------------------//
   // XLA:CPU options.
   //--------------------------------------------------------------------------//
+
+  enum CpuOptPreset {
+    CPU_OPT_PRESET_DEFAULT = 0;
+    CPU_OPT_PRESET_FAST_RUNTIME = 1;
+    CPU_OPT_PRESET_FAST_COMPILE = 2;
+  }
+
   // clang-format off
   // go/keep-sorted start newline_separated=yes skip_lines=1 ignore_prefixes=["optional bool","optional int32","optional string", "optional XnnGraphFusionMode", "repeated LibraryFusionType", "optional CpuSchedulerType"]
   // clang-format on
@@ -208,6 +215,9 @@ message DebugOptions {
     XNN_GRAPH_FUSION_MODE_BYPASS_COST_MODEL = 3;
   }
 
+  // XLA:CPU optimization preset.
+  optional CpuOptPreset xla_cpu_opt_preset = 466;
+
   // The number of seconds to wait before terminating a rendezvous call
   optional int32 xla_cpu_collective_call_terminate_timeout_seconds = 417;
 
@@ -1443,7 +1453,7 @@ message DebugOptions {
   // Note: when adding a new flag, please add it to one of the hardware-specific
   // or hardware-agnostic sections at the top of this proto message.
 
-  // Next id: 466
+  // Next id: 467
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.