Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,20 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
return true;
};

auto setter_for_xla_cpu_opt_preset =
[debug_options](absl::string_view input) {
std::string upper_input = absl::AsciiStrToUpper(input);
if (!absl::StartsWith(upper_input, "CPU_OPT_PRESET_")) {
upper_input = absl::StrCat("CPU_OPT_PRESET_", upper_input);
}
DebugOptions::CpuOptPreset preset;
if (!DebugOptions::CpuOptPreset_Parse(upper_input, &preset)) {
return false;
}
debug_options->set_xla_cpu_opt_preset(preset);
return true;
};

auto setter_for_xla_cpu_enable_concurrency_optimized_scheduler =
[debug_options](bool value) {
if (value) {
Expand Down Expand Up @@ -1219,6 +1233,10 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_cpu_use_acl", bool_setter_for(&DebugOptions::set_xla_cpu_use_acl),
debug_options->xla_cpu_use_acl(),
"Generate calls to ACL (Arm Compute Library) in the CPU backend."));
// Registers the --xla_cpu_opt_preset flag. The default shown to users is the
// symbolic name of the currently-configured preset.
flag_list->push_back(tsl::Flag(
    "xla_cpu_opt_preset", setter_for_xla_cpu_opt_preset,
    DebugOptions::CpuOptPreset_Name(debug_options->xla_cpu_opt_preset()),
    // List every accepted value (the setter also accepts DEFAULT) and note
    // that matching is case-insensitive with an optional enum prefix.
    "Set the CPU optimization preset. Valid values: DEFAULT, FAST_RUNTIME, "
    "FAST_COMPILE (case-insensitive; the CPU_OPT_PRESET_ prefix is "
    "optional)."));
flag_list->push_back(
tsl::Flag("xla_cpu_use_fusion_emitters",
bool_setter_for(&DebugOptions::set_xla_cpu_use_fusion_emitters),
Expand Down
61 changes: 49 additions & 12 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,12 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
const int64_t num_partitions = module->config().num_partitions();
const bool use_fusion_emitters =
module->config().debug_options().xla_cpu_use_fusion_emitters();
bool use_shardy_partitioner = module->config().use_shardy_partitioner();
bool flatten_before_fusion = !options::FlattenAfterFusion(module->config());
const bool use_shardy_partitioner = module->config().use_shardy_partitioner();
const bool fast_compile =
module->config().debug_options().xla_cpu_opt_preset() ==
xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE;
const bool flatten_before_fusion =
!options::FlattenAfterFusion(module->config()) && !fast_compile;

// Replace asynchronous collectives with synchronous ones.
HloPassPipeline async_collective_pipeline("async-collective");
Expand Down Expand Up @@ -873,6 +877,9 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
}

#ifdef XLA_ONEDNN
// We must re-obtain the reference to debug_options because the previous
// reference is invalidated by the pipeline passes that run before this
// lambda.
const DebugOptions& debug_options = module->config().debug_options();
if ((debug_options.xla_cpu_use_onednn() ||
debug_options.xla_cpu_experimental_onednn_custom_call()) &&
Expand Down Expand Up @@ -964,6 +971,10 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
const auto& debug_options = module->config().debug_options();
const bool use_fusion_emitters = debug_options.xla_cpu_use_fusion_emitters();
bool flatten_after_fusion = options::FlattenAfterFusion(module->config());
if (debug_options.xla_cpu_opt_preset() ==
xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE) {
flatten_after_fusion = true;
}
HloPassPipeline pipeline("HLO passes after layout assignment");

{
Expand Down Expand Up @@ -2101,21 +2112,35 @@ absl::StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
options.cpu_target_config->cpu_target_machine_options.value();
}

const bool fast_compile =
module->config().debug_options().xla_cpu_opt_preset() ==
xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE;

llvm::CodeGenOptLevel opt_level =
IrCompiler::GetCodeGenOptLevel(module->config());
if (fast_compile &&
!module->config().debug_options().has_xla_backend_optimization_level()) {
opt_level = llvm::CodeGenOptLevel::Less;
}

const bool disable_expensive_passes =
module->config().debug_options().xla_llvm_disable_expensive_passes();

// Options for compiling LLVM IR to machine code.
IrCompiler::Options ir_compiler_options{
/*optimization_level=*/IrCompiler::GetCodeGenOptLevel(module->config()),
/*optimize_for_size=*/options::OptimizeForSizeRequested(module->config()),
/*optimization_level=*/opt_level,
/*optimize_for_size=*/
options::OptimizeForSizeRequested(module->config()),
/*target_machine_options=*/
target_machine_options,
/*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(module->config()),
/*disable_expensive_passes=*/
module->config().debug_options().xla_llvm_disable_expensive_passes(),
/*slp_vectorizer_disabled=*/
/*disable_expensive_passes=*/disable_expensive_passes,
/*disable_slp_vectorizer=*/
options::SlpVectorizerDisabled(module->config()),
/*disable_loop_unrolling=*/
options::DisableLoopUnrolling(module->config()),
/*disable_platform_dependent_math=*/
options::DisablePlatformDependentMath(module->config()),
options::DisablePlatformDependentMath(module->config()) || fast_compile,
};

ThunkEmitter::Options thunk_emitter_options = {
Expand Down Expand Up @@ -2244,20 +2269,32 @@ CpuCompiler::CompileAheadOfTimeThunks(
triple.normalize(), target_machine->getTargetCPU(),
target_machine->getTargetFeatureString());

const bool fast_compile =
module->config().debug_options().xla_cpu_opt_preset() ==
xla::DebugOptions::CPU_OPT_PRESET_FAST_COMPILE;

llvm::CodeGenOptLevel opt_level = target_machine->getOptLevel();
if (fast_compile &&
!module->config().debug_options().has_xla_backend_optimization_level()) {
opt_level = llvm::CodeGenOptLevel::Less;
}

const bool disable_expensive_passes =
module->config().debug_options().xla_llvm_disable_expensive_passes();

IrCompiler::Options ir_compiler_options = {
/*optimization_level=*/target_machine->getOptLevel(),
/*optimization_level=*/opt_level,
/*optimize_for_size=*/
options::OptimizeForSizeRequested(module->config()),
/*target_machine_options=*/target_machine_options,
/*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(module->config()),
/*disable_expensive_passes=*/
module->config().debug_options().xla_llvm_disable_expensive_passes(),
/*disable_expensive_passes=*/disable_expensive_passes,
/*disable_slp_vectorizer=*/
options::SlpVectorizerDisabled(module->config()),
/*disable_loop_unrolling=*/
options::DisableLoopUnrolling(module->config()),
/*disable_platform_dependent_math=*/
options::DisablePlatformDependentMath(module->config()),
options::DisablePlatformDependentMath(module->config()) || fast_compile,
/*dfsan_enabled=*/aot_options.sanitize_dataflow(),
/*dfsan_abilists_enabled=*/aot_options.sanitize_abilists_dataflow()};

Expand Down
12 changes: 11 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ message DebugOptions {
//--------------------------------------------------------------------------//
// XLA:CPU options.
//--------------------------------------------------------------------------//

// High-level presets trading off compilation time against generated-code
// performance in the XLA:CPU backend.
enum CpuOptPreset {
// No preset selected; the standard compilation pipeline is used.
CPU_OPT_PRESET_DEFAULT = 0;
// Prefer faster generated code. NOTE(review): currently appears to behave
// like DEFAULT in the compiler — confirm before relying on it.
CPU_OPT_PRESET_FAST_RUNTIME = 1;
// Prefer faster compilation: lowers the LLVM codegen optimization level
// (unless xla_backend_optimization_level is set explicitly) and disables
// platform-dependent math lowering; generated code may run slower.
CPU_OPT_PRESET_FAST_COMPILE = 2;
}

// clang-format off
// go/keep-sorted start newline_separated=yes skip_lines=1 ignore_prefixes=["optional bool","optional int32","optional string", "optional XnnGraphFusionMode", "repeated LibraryFusionType", "optional CpuSchedulerType"]
// clang-format on
Expand Down Expand Up @@ -208,6 +215,9 @@ message DebugOptions {
XNN_GRAPH_FUSION_MODE_BYPASS_COST_MODEL = 3;
}

// XLA:CPU optimization preset controlling the compile-time vs. runtime
// performance trade-off. See CpuOptPreset for the available values.
optional CpuOptPreset xla_cpu_opt_preset = 466;

// The number of seconds to wait before terminating a rendezvous call
optional int32 xla_cpu_collective_call_terminate_timeout_seconds = 417;

Expand Down Expand Up @@ -1443,7 +1453,7 @@ message DebugOptions {
// Note: when adding a new flag, please add it to one of the hardware-specific
// or hardware-agnostic sections at the top of this proto message.

// Next id: 466
// Next id: 467

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down
Loading