diff --git a/include/NZSL/Ast/Compare.inl b/include/NZSL/Ast/Compare.inl index 15f9cccb..275b7517 100644 --- a/include/NZSL/Ast/Compare.inl +++ b/include/NZSL/Ast/Compare.inl @@ -834,6 +834,12 @@ namespace nzsl::Ast bool Compare(const ScopedStatement& lhs, const ScopedStatement& rhs, const ComparisonParams& params) { + if (!Compare(lhs.targetType, rhs.targetType, params)) + return false; + + if (!Compare(lhs.targetVersion, rhs.targetVersion, params)) + return false; + if (!Compare(lhs.statement, rhs.statement, params)) return false; diff --git a/include/NZSL/Ast/Enums.hpp b/include/NZSL/Ast/Enums.hpp index ecdecdc0..30b8534b 100644 --- a/include/NZSL/Ast/Enums.hpp +++ b/include/NZSL/Ast/Enums.hpp @@ -26,7 +26,7 @@ namespace nzsl::Ast enum class AttributeType { - // Next free ID: 20 + // Next free ID: 21 AutoBinding = 17, //< Incremental binding index (external block only) Author = 12, //< Module author (module statement only) - has argument version string Binding = 0, //< Binding (external var only) - has argument index @@ -47,6 +47,7 @@ namespace nzsl::Ast Tag = 16, //< Tag (external block and external var only) - has argument string Unroll = 11, //< Unroll (for/for each only) - has argument mode Workgroup = 18, //< Work group size (function only) - has arguments X, Y, Z + Target = 20, //< Mark a scope as target specific - has arguments target string, version (optional) }; enum class BinaryType @@ -268,6 +269,13 @@ namespace nzsl::Ast Minus = 1, //< -v Plus = 2, //< +v }; + + enum class TargetType + { + GLSL = 0, + GLES = 1, + SPIRV = 3, + }; } #endif // NZSL_AST_ENUMS_HPP diff --git a/include/NZSL/Ast/Nodes.hpp b/include/NZSL/Ast/Nodes.hpp index c2b45500..0b6c2878 100644 --- a/include/NZSL/Ast/Nodes.hpp +++ b/include/NZSL/Ast/Nodes.hpp @@ -526,6 +526,8 @@ namespace nzsl::Ast NodeType GetType() const override; void Visit(StatementVisitor& visitor) override; + ExpressionValue targetType; + ExpressionValue targetVersion; StatementPtr statement; }; diff --git a/include/NZSL/Lang/ErrorList.hpp b/include/NZSL/Lang/ErrorList.hpp index 2c90368c..778440e5 100644 --- a/include/NZSL/Lang/ErrorList.hpp +++ b/include/NZSL/Lang/ErrorList.hpp @@ -39,6 +39,7 @@ NZSL_SHADERLANG_PARSER_ERROR(AttributeInvalidParameter, "invalid parameter {} fo NZSL_SHADERLANG_PARSER_ERROR(AttributeMultipleUnique, "attribute {} can only be present once", Ast::AttributeType) NZSL_SHADERLANG_PARSER_ERROR(AttributeParameterIdentifier, "attribute {} parameter can only be an identifier", Ast::AttributeType) NZSL_SHADERLANG_PARSER_ERROR(AttributeUnexpectedParameterCount, "attribute {} expects {} arguments, got {}", Ast::AttributeType, std::size_t, std::size_t) +NZSL_SHADERLANG_PARSER_ERROR(AttributeInvalidTargetVersion, "invalid target version {} for target {}", std::int32_t, std::string) NZSL_SHADERLANG_PARSER_ERROR(ExpectedToken, "expected token {}, got {}", TokenType, TokenType) NZSL_SHADERLANG_PARSER_ERROR(DuplicateIdentifier, "duplicate identifier") NZSL_SHADERLANG_PARSER_ERROR(DuplicateModule, "duplicate module") diff --git a/include/NZSL/LangWriter.hpp b/include/NZSL/LangWriter.hpp index 69c3e650..81564769 100644 --- a/include/NZSL/LangWriter.hpp +++ b/include/NZSL/LangWriter.hpp @@ -58,6 +58,7 @@ namespace nzsl struct TagAttribute; struct UnrollAttribute; struct WorkgroupAttribute; + struct TargetAttribute; void Append(const Ast::AliasType& type); void Append(const Ast::ArrayType& type); @@ -104,6 +105,7 @@ namespace nzsl void AppendAttribute(TagAttribute attribute); void AppendAttribute(UnrollAttribute attribute); void AppendAttribute(WorkgroupAttribute attribute); + void AppendAttribute(TargetAttribute attribute); void AppendComment(std::string_view section); void AppendCommentSection(std::string_view section); void AppendHeader(); diff --git a/include/NZSL/Parser.hpp b/include/NZSL/Parser.hpp index d6401855..f4fafb2e 100644 --- a/include/NZSL/Parser.hpp +++ b/include/NZSL/Parser.hpp @@ -31,6 +31,7 @@ namespace nzsl static std::string_view ToString(Ast::LoopUnroll loopUnroll); static std::string_view ToString(Ast::MemoryLayout memoryLayout); static std::string_view ToString(Ast::ModuleFeature moduleFeature); + static std::string_view ToString(Ast::TargetType targetType); static std::string_view ToString(ShaderStageType shaderStage); private: @@ -74,7 +75,7 @@ namespace nzsl Ast::StatementPtr ParseReturnStatement(); Ast::StatementPtr ParseRootStatement(std::vector attributes = {}); Ast::StatementPtr ParseSingleStatement(); - Ast::StatementPtr ParseStatement(); + Ast::StatementPtr ParseStatement(std::vector attributes = {}); std::vector ParseStatementList(SourceLocation* sourceLocation); Ast::StatementPtr ParseStructDeclaration(std::vector attributes = {}); Ast::StatementPtr ParseVariableDeclaration(); diff --git a/src/NZSL/Ast/AstSerializer.cpp b/src/NZSL/Ast/AstSerializer.cpp index de70fe22..149c9ea8 100644 --- a/src/NZSL/Ast/AstSerializer.cpp +++ b/src/NZSL/Ast/AstSerializer.cpp @@ -504,6 +504,8 @@ namespace nzsl::Ast void SerializerBase::Serialize(ScopedStatement& node) { + ExprValue(node.targetType); + ExprValue(node.targetVersion); Node(node.statement); } diff --git a/src/NZSL/Ast/Cloner.cpp b/src/NZSL/Ast/Cloner.cpp index 5b564e79..f0bdd164 100644 --- a/src/NZSL/Ast/Cloner.cpp +++ b/src/NZSL/Ast/Cloner.cpp @@ -332,6 +332,8 @@ namespace nzsl::Ast StatementPtr Cloner::Clone(ScopedStatement& node) { auto clone = std::make_unique(); + clone->targetType = Clone(node.targetType); + clone->targetVersion = Clone(node.targetVersion); clone->statement = CloneStatement(node.statement); clone->sourceLocation = node.sourceLocation; diff --git a/src/NZSL/GlslWriter.cpp b/src/NZSL/GlslWriter.cpp index ecef344f..9f850627 100644 --- a/src/NZSL/GlslWriter.cpp +++ b/src/NZSL/GlslWriter.cpp @@ -2703,6 +2703,20 @@ namespace nzsl void GlslWriter::Visit(Ast::ScopedStatement& node) { + if (node.targetType.IsResultingValue()) + { + unsigned int glVersion = m_environment.glMajorVersion * 100 + m_environment.glMinorVersion * 10; + auto targetType = node.targetType.GetResultingValue(); + std::uint32_t targetVersion = 0; + if (node.targetVersion.IsResultingValue()) + targetVersion = node.targetVersion.GetResultingValue(); + + const auto isGLSL = !m_environment.glES && targetType == Ast::TargetType::GLSL; + const auto isGLES = m_environment.glES && targetType == Ast::TargetType::GLES; + if (!((isGLSL || isGLES) && targetVersion <= glVersion)) + return; + } + EnterScope(); node.statement->Visit(*this); LeaveScope(true); diff --git a/src/NZSL/Lang/Errors.cpp b/src/NZSL/Lang/Errors.cpp index 4d3b5411..1ed315cf 100644 --- a/src/NZSL/Lang/Errors.cpp +++ b/src/NZSL/Lang/Errors.cpp @@ -84,6 +84,19 @@ struct fmt::formatter : formatter } }; +template<> +struct fmt::formatter : formatter +{ + template + auto format(const nzsl::Ast::TargetType& p, FormatContext& ctx) const -> decltype(ctx.out()) + { + auto it = nzsl::LangData::s_targets.find(p); + assert(it != nzsl::LangData::s_targets.end()); + + return formatter::format(it->second.identifier, ctx); + } +}; + namespace nzsl { std::string_view ToString(ErrorCategory errorCategory) diff --git a/src/NZSL/Lang/LangData.hpp b/src/NZSL/Lang/LangData.hpp index 56f7e6c8..d45da15a 100644 --- a/src/NZSL/Lang/LangData.hpp +++ b/src/NZSL/Lang/LangData.hpp @@ -43,7 +43,8 @@ namespace nzsl::LangData { Ast::AttributeType::Set, { "set" } }, { Ast::AttributeType::Tag, { "tag" } }, { Ast::AttributeType::Unroll, { "unroll" } }, - { Ast::AttributeType::Workgroup, { "workgroup" } } + { Ast::AttributeType::Workgroup, { "workgroup" } }, + { Ast::AttributeType::Target, { "target" } } }); struct BuiltinData @@ -259,6 +260,17 @@ namespace nzsl::LangData { Ast::LoopUnroll::Hint, { "hint" } }, { Ast::LoopUnroll::Never, { "never" } } }); + + struct TargetData + { + std::string_view identifier; + }; + + constexpr auto s_targets = frozen::make_unordered_map({ + { Ast::TargetType::GLSL, { "glsl" } }, + { Ast::TargetType::GLES, { "gles" } }, + { Ast::TargetType::SPIRV, { "spirv" } }, + }); } #endif // NZSL_LANG_LANGDATA_HPP diff --git a/src/NZSL/LangWriter.cpp b/src/NZSL/LangWriter.cpp index e1151ff9..ee00e44a 100644 --- a/src/NZSL/LangWriter.cpp +++ b/src/NZSL/LangWriter.cpp @@ -173,6 +173,14 @@ namespace nzsl bool HasValue() const { return workgroup.HasValue(); } }; + struct LangWriter::TargetAttribute + { + const Ast::ExpressionValue& target; + const Ast::ExpressionValue& version; + + bool HasValue() const { return target.HasValue(); } + }; + struct LangWriter::State { struct Identifier @@ -780,6 +788,32 @@ namespace nzsl Append(")"); } + void LangWriter::AppendAttribute(TargetAttribute attribute) + { + if (!attribute.HasValue()) + return; + + Append("target("); + + if (attribute.target.IsResultingValue()) + Append(Parser::ToString(attribute.target.GetResultingValue())); + else + attribute.target.GetExpression()->Visit(*this); + + if (attribute.version.HasValue()) + { + if (attribute.version.IsResultingValue()) + { + Append(", "); + Append(attribute.version.GetResultingValue()); + } + else + attribute.version.GetExpression()->Visit(*this); + } + + Append(")"); + } + void LangWriter::AppendComment(std::string_view section) { std::size_t lineFeed = section.find('\n'); @@ -1690,6 +1724,9 @@ namespace nzsl void LangWriter::Visit(Ast::ScopedStatement& node) { + AppendAttributes(true, + TargetAttribute{ node.targetType, node.targetVersion }); + EnterScope(); node.statement->Visit(*this); LeaveScope(); diff --git a/src/NZSL/Parser.cpp b/src/NZSL/Parser.cpp index a9dd2702..6860d355 100644 --- a/src/NZSL/Parser.cpp +++ b/src/NZSL/Parser.cpp @@ -82,6 +82,7 @@ namespace nzsl constexpr auto s_layoutMapping = BuildIdentifierMapping(LangData::s_memoryLayouts); constexpr auto s_moduleFeatureMapping = BuildIdentifierMapping(LangData::s_moduleFeatures); constexpr auto s_unrollModeMapping = BuildIdentifierMapping(LangData::s_unrollModes); + constexpr auto s_targetMapping = BuildIdentifierMapping(LangData::s_targets); } Ast::ModulePtr Parser::Parse(const std::vector& tokens) @@ -177,6 +178,14 @@ namespace nzsl return it->second.identifier; } + std::string_view Parser::ToString(Ast::TargetType targetType) + { + auto it = LangData::s_targets.find(targetType); + assert(it != LangData::s_targets.end()); + + return it->second.identifier; + } + std::string_view Parser::ToString(ShaderStageType shaderStage) { auto it = LangData::s_entryPoints.find(shaderStage); @@ -1294,6 +1303,13 @@ namespace nzsl attributes.clear(); break; + case TokenType::OpenCurlyBracket: + if (attributes.empty()) + throw ParserUnexpectedTokenError{ token.location, token.type }; + + statement = ParseStatement(std::move(attributes)); + attributes.clear(); + break; default: throw ParserUnexpectedTokenError{ token.location, token.type }; } @@ -1303,14 +1319,72 @@ namespace nzsl return statement; } - Ast::StatementPtr Parser::ParseStatement() + Ast::StatementPtr Parser::ParseStatement(std::vector attributes) { + NAZARA_USE_ANONYMOUS_NAMESPACE + if (Peek().type == TokenType::OpenCurlyBracket) { auto multiStatement = ShaderBuilder::MultiStatement(); multiStatement->statements = ParseStatementList(&multiStatement->sourceLocation); - return ShaderBuilder::Scoped(std::move(multiStatement)); + auto scopedStatement = ShaderBuilder::Scoped(std::move(multiStatement)); + for (auto&& attribute : attributes) + { + switch (attribute.type) + { + case Ast::AttributeType::Target: + { + if (scopedStatement->targetType.HasValue()) + throw ParserAttributeMultipleUniqueError{ attribute.sourceLocation, attribute.type }; + + if (attribute.args.empty()) + throw ParserAttributeUnexpectedParameterCountError{ attribute.sourceLocation, attribute.type, 1, attribute.args.size() }; + + const auto& targetTypeArg = attribute.args[0]; + + if (attribute.args[0]->GetType() != Ast::NodeType::IdentifierExpression) + throw ParserAttributeParameterIdentifierError{ targetTypeArg->sourceLocation, attribute.type }; + + bool hasExplicitTargetVersion = attribute.args.size() == 2; + + auto targetTypeStr = static_cast(*targetTypeArg).identifier; + auto it = s_targetMapping.find(targetTypeStr); + if (it == s_targetMapping.end()) + throw ParserAttributeInvalidParameterError{ targetTypeArg->sourceLocation, targetTypeStr, attribute.type }; + + scopedStatement->targetType = it->second; + + if (hasExplicitTargetVersion) + { + const auto& targetVersionArg = attribute.args[1]; + + if (targetVersionArg->GetType() != Ast::NodeType::ConstantValueExpression) + throw ParserAttributeParameterIdentifierError{ targetVersionArg->sourceLocation, attribute.type }; + + auto targetVersionValue = static_cast(*targetVersionArg).value; + if (std::holds_alternative(targetVersionValue)) + { + auto targetVersion = std::get(targetVersionValue); + if (targetVersion < 0) + throw ParserAttributeInvalidTargetVersionError{ targetVersionArg->sourceLocation, targetVersion, targetTypeStr }; + + scopedStatement->targetVersion = Nz::SafeCast(targetVersion); + } + else if (std::holds_alternative(targetVersionValue)) + scopedStatement->targetVersion = std::get(targetVersionValue); + else + throw ParserAttributeInvalidTargetVersionError{ targetVersionArg->sourceLocation, 0, targetTypeStr }; + } + + break; + } + default: + throw ParserUnexpectedAttributeError{ attribute.sourceLocation, attribute.type, "scoped statement" }; + } + } + + return scopedStatement; } else return ParseSingleStatement(); diff --git a/src/NZSL/SpirV/SpirvAstVisitor.cpp b/src/NZSL/SpirV/SpirvAstVisitor.cpp index bfbf3ce2..fc487500 100644 --- a/src/NZSL/SpirV/SpirvAstVisitor.cpp +++ b/src/NZSL/SpirV/SpirvAstVisitor.cpp @@ -1031,6 +1031,19 @@ namespace nzsl void SpirvAstVisitor::Visit(Ast::ScopedStatement& node) { + if (node.targetType.IsResultingValue()) + { + unsigned int spirvVersion = m_writer.m_environment.spvMajorVersion * 100 + m_writer.m_environment.spvMinorVersion * 10; + auto targetType = node.targetType.GetResultingValue(); + std::uint32_t targetVersion = 0; + if (node.targetVersion.IsResultingValue()) + targetVersion = node.targetVersion.GetResultingValue(); + + const auto isSPIRV = targetType == Ast::TargetType::SPIRV; + if (!(isSPIRV && targetVersion <= spirvVersion)) + return; + } + node.statement->Visit(*this); }