diff --git a/.references/spirv b/.references/spirv index 6aebf95..e153333 100644 --- a/.references/spirv +++ b/.references/spirv @@ -1 +1 @@ -04f10f6 +f31ca17 diff --git a/.references/spirv-tools b/.references/spirv-tools index bbd0e3f..67eb38b 100644 --- a/.references/spirv-tools +++ b/.references/spirv-tools @@ -1 +1 @@ -a66a95e +d9d2ec1 diff --git a/spirv-tools/source/opt/folding_rules.cpp b/spirv-tools/source/opt/folding_rules.cpp index a20d904..0a5525c 100644 --- a/spirv-tools/source/opt/folding_rules.cpp +++ b/spirv-tools/source/opt/folding_rules.cpp @@ -2603,6 +2603,138 @@ FoldingRule RedundantLogicalNot() { }; } +// Cases handled: +// ((a ? C0 : C1) == C2) = ((a ? (C0 == C2) : (C1 == C2)) +// ((a ? C0 : C1) != C2) = ((a ? (C0 != C2) : (C1 != C2)) +// ((a ? C0 : C1) < C2) = ((a ? (C0 < C2) : (C1 < C2)) +// ((a ? C0 : C1) <= C2) = ((a ? (C0 <= C2) : (C1 <= C2)) +// ((a ? C0 : C1) > C2) = ((a ? (C0 > C2) : (C1 > C2)) +// ((a ? C0 : C1) >= C2) = ((a ? (C0 >= C2) : (C1 >= C2)) +// ((a ? C0 : C1) || C2) = ((a ? (C0 || C2) : (C1 || C2)) +// ((a ? C0 : C1) && C2) = ((a ? (C0 && C2) : (C1 && C2)) +// ((a ? C0 : C1) + C2) = ((a ? (C0 + C2) : (C1 + C2)) +// ((a ? C0 : C1) - C2) = ((a ? (C0 - C2) : (C1 - C2)) +// ((a ? C0 : C1) * C2) = ((a ? (C0 * C2) : (C1 * C2)) +// ((a ? C0 : C1) / C2) = ((a ? (C0 / C2) : (C1 / C2)) +// ((a ? C0 : C1) >> C2) = ((a ? (C0 >> C2) : (C1 >> C2)) +// ((a ? C0 : C1) << C2) = ((a ? (C0 << C2) : (C1 << C2)) +// ((a ? C0 : C1) ^ C2) = ((a ? (C0 ^ C2) : (C1 ^ C2)) +// ((a ? C0 : C1) | C2) = ((a ? (C0 | C2) : (C1 | C2)) +// ((a ? C0 : C1) & C2) = ((a ? (C0 & C2) : (C1 & C2)) +static const constexpr spv::Op MergeBinaryOpSelectOps[] = { + spv::Op::OpLogicalEqual, + spv::Op::OpLogicalNotEqual, + spv::Op::OpLogicalAnd, + spv::Op::OpLogicalOr, + spv::Op::OpIEqual, + spv::Op::OpINotEqual, + spv::Op::OpUGreaterThan, + spv::Op::OpSGreaterThan, + spv::Op::OpUGreaterThanEqual, + spv::Op::OpSGreaterThanEqual, + spv::Op::OpULessThan, + spv::Op::OpSLessThan, + spv::Op::OpULessThanEqual, + spv::Op::OpSLessThanEqual, + spv::Op::OpFOrdEqual, + spv::Op::OpFUnordEqual, + spv::Op::OpFOrdNotEqual, + spv::Op::OpFUnordNotEqual, + spv::Op::OpFOrdLessThan, + spv::Op::OpFUnordLessThan, + spv::Op::OpFOrdGreaterThan, + spv::Op::OpFUnordGreaterThan, + spv::Op::OpFOrdLessThanEqual, + spv::Op::OpFUnordLessThanEqual, + spv::Op::OpFOrdGreaterThanEqual, + spv::Op::OpFUnordGreaterThanEqual, + spv::Op::OpIAdd, + spv::Op::OpFAdd, + spv::Op::OpISub, + spv::Op::OpFSub, + spv::Op::OpIMul, + spv::Op::OpFMul, + spv::Op::OpUDiv, + spv::Op::OpSDiv, + spv::Op::OpFDiv, + spv::Op::OpVectorTimesScalar, + spv::Op::OpShiftRightLogical, + spv::Op::OpShiftRightArithmetic, + spv::Op::OpShiftLeftLogical, + spv::Op::OpBitwiseXor, + spv::Op::OpBitwiseOr, + spv::Op::OpBitwiseAnd}; + +FoldingRule MergeBinaryOpSelect(spv::Op opcode) { + assert(std::find(std::begin(MergeBinaryOpSelectOps), + std::end(MergeBinaryOpSelectOps), + opcode) != std::end(MergeBinaryOpSelectOps) && + "Wrong opcode."); + + return [opcode](IRContext* context, Instruction* inst, + const std::vector& constants) { + const analysis::Constant* const_input = ConstInput(constants); + if (!const_input) { + return false; + } + Instruction* non_const = NonConstInput(context, constants[0], inst); + if (non_const->opcode() != spv::Op::OpSelect) { + return false; + } + std::vector select_constants = + context->get_constant_mgr()->GetOperandConstants(non_const); + if (!select_constants[1] || !select_constants[2]) { + return false; + } + + InstructionBuilder ir_builder( + context, inst, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + Instruction *lhs, *rhs; + if (constants[0]) { + lhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, + inst->GetSingleWordInOperand(0), + non_const->GetSingleWordInOperand(1)); + rhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, + inst->GetSingleWordInOperand(0), + non_const->GetSingleWordInOperand(2)); + } else { + lhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, + non_const->GetSingleWordInOperand(1), + inst->GetSingleWordInOperand(1)); + rhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, + non_const->GetSingleWordInOperand(2), + inst->GetSingleWordInOperand(1)); + } + + if (!lhs || !rhs) { + return false; + } + + if (context->get_instruction_folder().FoldInstruction(lhs)) { + context->AnalyzeDefUse(lhs); + while (lhs->opcode() == spv::Op::OpCopyObject) { + lhs = + context->get_def_use_mgr()->GetDef(lhs->GetSingleWordInOperand(0)); + } + } + if (context->get_instruction_folder().FoldInstruction(rhs)) { + context->AnalyzeDefUse(rhs); + while (rhs->opcode() == spv::Op::OpCopyObject) { + rhs = + context->get_def_use_mgr()->GetDef(rhs->GetSingleWordInOperand(0)); + } + } + inst->SetOpcode(spv::Op::OpSelect); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {non_const->GetSingleWordInOperand(0)}}, + {SPV_OPERAND_TYPE_ID, {lhs->result_id()}}, + {SPV_OPERAND_TYPE_ID, {rhs->result_id()}}}); + return true; + }; +} + // Fold OpLogicalNot instructions that follow a comparison, // if the comparison is only used by that instruction. // @@ -2721,6 +2853,45 @@ FoldingRule FoldLogicalNotComparison() { }; } +// (a == true) = a +// (a == false) = !a +// (a != true) = !a +// (a != false) = a +FoldingRule RedundantLogicalEqual() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpLogicalEqual || + inst->opcode() == spv::Op::OpLogicalNotEqual); + + const analysis::Constant* const_input = ConstInput(constants); + if (!const_input) { + return false; + } + + analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); + if (inst->type_id() != + def_mgr->GetDef(inst->GetSingleWordInOperand(0))->type_id()) { + return false; + } + + std::optional uniform_const = GetBoolConstantKind(const_input); + if (!uniform_const) { + return false; + } + + bool direct_copy = inst->opcode() == spv::Op::OpLogicalEqual + ? uniform_const.value() + : !uniform_const.value(); + + inst->SetOpcode(direct_copy ? spv::Op::OpCopyObject + : spv::Op::OpLogicalNot); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, + {NonConstInput(context, constants[0], inst)->result_id()}}}); + return true; + }; +} + enum class FloatConstantKind { Unknown, Zero, One }; FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { @@ -3699,6 +3870,8 @@ void FoldingRules::AddFoldingRules() { rules_[op].push_back(RedundantBinaryLhs0To0(op)); for (auto op : ReassociateCommutiveBitwiseOps) rules_[op].push_back(ReassociateCommutiveBitwise(op)); + for (auto op : MergeBinaryOpSelectOps) + rules_[op].push_back(MergeBinaryOpSelect(op)); rules_[spv::Op::OpSDiv].push_back(RedundantSUDiv()); rules_[spv::Op::OpUDiv].push_back(RedundantSUDiv()); rules_[spv::Op::OpSMod].push_back(RedundantSUMod()); @@ -3797,6 +3970,9 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpLogicalNot].push_back(RedundantLogicalNot()); rules_[spv::Op::OpLogicalNot].push_back(FoldLogicalNotComparison()); + rules_[spv::Op::OpLogicalEqual].push_back(RedundantLogicalEqual()); + rules_[spv::Op::OpLogicalNotEqual].push_back(RedundantLogicalEqual()); + rules_[spv::Op::OpStore].push_back(StoringUndef()); rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle()); diff --git a/spirv-tools/spirv-tools/build-version.inc b/spirv-tools/spirv-tools/build-version.inc index e6c6e66..91ed412 100644 --- a/spirv-tools/spirv-tools/build-version.inc +++ b/spirv-tools/spirv-tools/build-version.inc @@ -1 +1 @@ -"v2026.2-dev", "SPIRV-Tools v2026.2-dev v2026.1-13-ga66a95ee" +"v2026.2-dev", "SPIRV-Tools v2026.2-dev v2026.1-16-gd9d2ec12"