diff --git a/.references/spirv-tools b/.references/spirv-tools index eddf1de..ffd197f 100644 --- a/.references/spirv-tools +++ b/.references/spirv-tools @@ -1 +1 @@ -1c69c17 +3ac12f1 diff --git a/spirv-tools/source/val/validate.cpp b/spirv-tools/source/val/validate.cpp index 1141968..ecc6b80 100644 --- a/spirv-tools/source/val/validate.cpp +++ b/spirv-tools/source/val/validate.cpp @@ -390,6 +390,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState( if (auto error = AtomicsPass(*vstate, &instruction)) return error; if (auto error = PrimitivesPass(*vstate, &instruction)) return error; if (auto error = BarriersPass(*vstate, &instruction)) return error; + if (auto error = DotProductPass(*vstate, &instruction)) return error; if (auto error = GroupPass(*vstate, &instruction)) return error; // Device-Side Enqueue // Pipe diff --git a/spirv-tools/source/val/validate.h b/spirv-tools/source/val/validate.h index 1c2328e..10025d9 100644 --- a/spirv-tools/source/val/validate.h +++ b/spirv-tools/source/val/validate.h @@ -180,6 +180,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of barrier instructions. spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst); +/// Validates correctness of DotProduct instructions. +spv_result_t DotProductPass(ValidationState_t& _, const Instruction* inst); + /// Validates correctness of Group (Kernel) instructions. spv_result_t GroupPass(ValidationState_t& _, const Instruction* inst); diff --git a/spirv-tools/source/val/validate_dot_product.cpp b/spirv-tools/source/val/validate_dot_product.cpp new file mode 100644 index 0000000..21cc4be --- /dev/null +++ b/spirv-tools/source/val/validate_dot_product.cpp @@ -0,0 +1,188 @@ +// Copyright (c) 2026 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "source/val/instruction.h" +#include "source/val/validate.h" +#include "source/val/validate_scopes.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateSameSignedDot(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_id = inst->type_id(); + if (!_.IsIntScalarType(result_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result must be an int scalar type."; + } + + const spv::Op opcode = inst->opcode(); + const bool has_accumulator = opcode == spv::Op::OpSDotAccSat || + opcode == spv::Op::OpUDotAccSat || + opcode == spv::Op::OpSUDotAccSat; + if (has_accumulator) { + const uint32_t accumulator_type = _.GetOperandTypeId(inst, 4); + if (accumulator_type != result_id) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result must be the same as the Accumulator type."; + } + } + + if (opcode == spv::Op::OpUDot || opcode == spv::Op::OpUDotAccSat) { + if (!_.IsIntScalarTypeWithSignedness(result_id, 0)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result must be an unsigned int scalar type."; + } + } + + const uint32_t vec_1_id = _.GetOperandTypeId(inst, 2); + const uint32_t vec_2_id = _.GetOperandTypeId(inst, 3); + + const bool is_vec_1_scalar = _.IsIntScalarType(vec_1_id, 32); + const bool is_vec_2_scalar = _.IsIntScalarType(vec_2_id, 32); + if (is_vec_1_scalar != is_vec_2_scalar) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "'Vector 1' and 'Vector 2' must be the same type."; + } else if (is_vec_1_scalar && is_vec_2_scalar) { + // If both are scalar, spec doesn't say Signedness needs to match + const uint32_t vec_1_width = _.GetBitWidth(vec_1_id); + const uint32_t vec_2_width = _.GetBitWidth(vec_2_id); + if (vec_1_width != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 1' to be 32-bit when a scalar."; + } else if (vec_2_width != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 2' to be 32-bit when a scalar."; + } + + // When packed, the result can be 8-bit + const uint32_t result_width = _.GetBitWidth(result_id); + if (result_width < 8) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result width (" << result_width + << ") must be greater than or equal to the packed vector width of " + "8"; + } + + // PackedVectorFormat4x8Bit is used when the "Vector" operand are really + // scalar + const uint32_t packed_operand = has_accumulator ? 6 : 5; + const bool has_packed_vec_format = + inst->operands().size() == packed_operand; + if (!has_packed_vec_format) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "'Vector 1' and 'Vector 2' are a 32-bit int scalar, but no " + "Packed Vector " + "Format was provided."; + } + } else { + // both should be vectors + + if (!_.IsVectorType(vec_1_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 1' to be an int scalar or vector."; + } else if (!_.IsVectorType(vec_2_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 2' to be an int scalar or vector."; + } + + const uint32_t vec_1_length = _.GetDimension(vec_1_id); + const uint32_t vec_2_length = _.GetDimension(vec_2_id); + // If using OpTypeVectorIdEXT with a spec constant, this can be evaluated + // when spec constants are frozen + if (vec_1_length != 0 && vec_2_length != 0 && + vec_1_length != vec_2_length) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "'Vector 1' is " << vec_1_length + << " components but 'Vector 2' is " << vec_2_length + << " components"; + } + + const uint32_t vec_1_type = _.GetComponentType(vec_1_id); + const uint32_t vec_2_type = _.GetComponentType(vec_2_id); + if (!_.IsIntScalarType(vec_1_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 1' to be a vector of integers."; + } else if (!_.IsIntScalarType(vec_2_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 2' to be a vector of integers."; + } + + const uint32_t vec_1_width = _.GetBitWidth(vec_1_type); + const uint32_t vec_2_width = _.GetBitWidth(vec_2_type); + if (vec_1_width != vec_2_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "'Vector 1' component is " << vec_1_width + << "-bit but 'Vector 2' component is " << vec_2_width << "-bit"; + } + + const uint32_t result_width = _.GetBitWidth(result_id); + if (result_width < vec_1_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result width (" << result_width + << ") must be greater than or equal to the vectors width (" + << vec_1_width << ")."; + } + + if (opcode == spv::Op::OpUDot || opcode == spv::Op::OpUDotAccSat) { + const bool vec_1_unsigned = + _.IsIntScalarTypeWithSignedness(vec_1_type, 0); + const bool vec_2_unsigned = + _.IsIntScalarTypeWithSignedness(vec_2_type, 0); + if (!vec_1_unsigned) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 1' to be an vector of unsigned integers."; + } else if (!vec_2_unsigned) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 2' to be an vector of unsigned integers."; + } + } else if (opcode == spv::Op::OpSUDot || opcode == spv::Op::OpSUDotAccSat) { + const bool vec_2_unsigned = + _.IsIntScalarTypeWithSignedness(vec_2_type, 0); + if (!vec_2_unsigned) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 'Vector 2' to be an vector of unsigned integers."; + } + } + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t DotProductPass(ValidationState_t& _, const Instruction* inst) { + const spv::Op opcode = inst->opcode(); + + switch (opcode) { + case spv::Op::OpSDot: + case spv::Op::OpUDot: + case spv::Op::OpSUDot: + case spv::Op::OpSDotAccSat: + case spv::Op::OpUDotAccSat: + case spv::Op::OpSUDotAccSat: + return ValidateSameSignedDot(_, inst); + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/spirv-tools/spirv-tools/build-version.inc b/spirv-tools/spirv-tools/build-version.inc index 30edf7a..6f3e367 100644 --- a/spirv-tools/spirv-tools/build-version.inc +++ b/spirv-tools/spirv-tools/build-version.inc @@ -1 +1 @@ -"v2026.2-dev", "SPIRV-Tools v2026.2-dev v2026.1-9-g1c69c179" +"v2026.2-dev", "SPIRV-Tools v2026.2-dev v2026.1-10-g3ac12f1e"