diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 726fbe3694cbf..1788bee26820a 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -1435,6 +1435,40 @@ absl::Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { operand_shape.ToString(true)); } } + + bool memory_space_is_compatible = [&]() { + if (!opts_.layout_sensitive) { + return true; + } + if (!operand_shape.has_layout() || !output_shape.has_layout()) { + return true; + } + auto is_constant = [](const HloInstruction* instruction) { + const HloInstruction* inst = instruction; + while (inst->opcode() == HloOpcode::kCopy) { + inst = inst->operand(0); + } + return inst->opcode() == HloOpcode::kConstant; + }; + if (is_constant(bitcast->operand(0))) { + return true; + } + bool operand_has_host_memory_space = + operand_shape.layout().memory_space() == Layout::kHostMemorySpace; + bool output_has_host_memory_space = + output_shape.layout().memory_space() == Layout::kHostMemorySpace; + return operand_has_host_memory_space == output_has_host_memory_space; + }(); + + if (!memory_space_is_compatible) { + return Internal( + "%s: Bitcast cannot have different memory spaces of output (%d) and " + "operand " + "(%d) (%s) (%s)", + bitcast->ToString(), output_shape.layout().memory_space(), + operand_shape.layout().memory_space(), output_shape.ToString(true), + operand_shape.ToString(true)); + } return absl::OkStatus(); }