diff --git a/.gdb_history b/.gdb_history new file mode 100644 index 00000000..3b663efb --- /dev/null +++ b/.gdb_history @@ -0,0 +1,15 @@ +b nzsl::WgslWriter::Visit(Ast::AccessIndexExpression&) +b nzsl::WgslWriter::Visit(nzsl::Ast::AccessIndexExpression&) +run /tmp/tes/main.nzsl --compile=wgsl -o /tmp/tes/ +p m_currentState->std140EmulationState +n +n +p m_currentState->std140EmulationState +n +n +n +p m_currentState->std140EmulationState +q +run ~/Downloads/tes/main.nzsl --compile=wgsl -o @stdout +bt +q diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 36130923..f9c0bafa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -164,7 +164,7 @@ jobs: # Setup compilation mode and install project dependencies - name: Configure xmake and install dependencies - run: xmake config --plat=${{ matrix.confs.plat }} --arch=${{ matrix.confs.arch }} --kind=${{ matrix.kind }} --mode=${{ matrix.confs.mode }} ${{ env.ADDITIONAL_CONF }} --ccache=n --yes + run: xmake config -vD --plat=${{ matrix.confs.plat }} --arch=${{ matrix.confs.arch }} --kind=${{ matrix.kind }} --mode=${{ matrix.confs.mode }} ${{ env.ADDITIONAL_CONF }} --ccache=n --yes # Save dependencies - name: Save cached xmake dependencies diff --git a/README.md b/README.md index 6974ce01..f5723ba1 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # Nazara Shading Language (NZSL) -NZSL is a shader language inspired by Rust and C++ which compiles to GLSL or SPIR-V (without depending on SPIRV-Cross). +NZSL is a shader language inspired by Rust and C++ which compiles to GLSL, WGSL or SPIR-V (without depending on SPIRV-Cross). ### Why a new shader language? @@ -48,9 +48,9 @@ fn main(input: VertOut) -> FragOut You can find precompiled binaries in the [releases](https://github.com/NazaraEngine/ShaderLang/releases). -NZSL is designed to be embedded in a game engine / game / graphics application that uses GLSL / SPIR-V for its shaders. +NZSL is designed to be embedded in a game engine / game / graphics application that uses GLSL / WGSL / SPIR-V for its shaders. -You can use it to generate GLSL, GLSL ES and SPIR-V in two non-exclusive ways: +You can use it to generate GLSL, GLSL ES, WGSL and SPIR-V in two non-exclusive ways: 1) Using the offline NZSL compiler (nzslc) ahead of time, in a way similar to glslang or glslc today. 2) Use NZSL as a library in your application to compile shaders in a dynamic way, just as they're needed (which can be used to benefit from supported extensions to improve generation). @@ -58,15 +58,16 @@ You can use it to generate GLSL, GLSL ES and SPIR-V in two non-exclusive ways: ### Offline compilation There are two binary tools you can use: -- **nzslc**: shader compiler, for compiling nzsl files to binary nzsl or directly to GLSL/SPIR-V. +- **nzslc**: shader compiler, for compiling nzsl files to binary nzsl or directly to GLSL/WGSL/SPIR-V. - **nzsla**: shader archiver, store and compress all your compiled shaders in a single file. **nzslc example usage:** - Validating shader: `nzslc file.nzsl` -- Compile a shader to GLSL: `nzslc --compile=glsl file.nzsl` -- Compile a shader to SPIR-V: `nzslc --compile=spv file.nzsl` -- Compile a shader using modules to both GLSL and SPIR-V header includable version: `nzslc --module module_file.nzsl --module module_folder/ --compile=glsl-header,spv-header file.nzsl` +- Compile a shader to GLSL: `nzsl --compile=glsl file.nzsl` +- Compile a shader to WGSL: `nzsl --compile=wgsl file.nzsl` +- Compile a shader to SPIR-V: `nzsl --compile=spv file.nzsl` +- Compile a shader using modules to GLSL, WGSL and SPIR-V header includable version: `nzsl --module module_file.nzsl --module module_folder/ --compile=glsl-header,wgsl-header,spv-header file.nzsl` Run `nzslc -h` to see all supported options. @@ -86,6 +87,7 @@ Run `nzsla -h` to see all supported options. #include #include #include +#include int main() { @@ -98,10 +100,14 @@ int main() nzsl::GlslWriter glslWriter; nzsl::GlslWriter::Output output = glslWriter.Generate(shaderAst); // output.code contains GLSL that can directly be used by OpenGL + + nzsl::WgslWriter wgslWriter; + nzsl::WgslWriter::Output output = wgslWriter.Generate(shaderAst); + // output.code contains WGSL that can directly be used by WebGPU (or any native implementation) } ``` -The library contains a lot of options to customize the generation process (target SPIR-V/GLSL version, GLSL ES, gl_Position.y flipping, gl_Position.z remapping to match Vulkan semantics, supported OpenGL extensions, etc.). +The library contains a lot of options to customize the generation process (target SPIR-V/GLSL version, GLSL ES, gl_Position.y flipping, gl_Position.z remapping to match Vulkan semantics, supported OpenGL extensions, supported WebGPU features, etc.). ## Integration @@ -136,13 +142,11 @@ At one of my previous working place we were using huge HLSL-derived shaders with NZSL is designed to be small, fast and easy to debug, for example NZSL to GLSL retains a lot of the source code information which could be lost during SSA (SPIR-V) translation, even with debug symbols enabled. -## Is there a DXIL/WGSL backend? - -Not yet, as I don't target Direct3D or WebGPU yet. - -DXIL is not very different from SPIR-V and WGSL looks a lot like NZSL so it should be quite easy to add, though. +## Is there a DXIL backend? -See [this issue](https://github.com/NazaraEngine/ShaderLang/issues/13) for WGSL. +Not yet, as I don't target Direct3D yet.\ +DXIL is not very different from SPIR-V so it should be quite easy to add, though.\ +Note that [Shader Model 7 will accept SPIR-V](https://devblogs.microsoft.com/directx/directx-adopting-spir-v/) so NZSL will be usable with Direct3D. ## Are there limitations? diff --git a/include/CNZSL/CNZSL.h b/include/CNZSL/CNZSL.h index 7c3efa4d..eb2b0784 100644 --- a/include/CNZSL/CNZSL.h +++ b/include/CNZSL/CNZSL.h @@ -16,5 +16,6 @@ #include #include #include +#include #endif /* CNZSL_CNZSL_H */ diff --git a/include/CNZSL/Config.h b/include/CNZSL/Config.h index f2adfd79..784abde1 100644 --- a/include/CNZSL/Config.h +++ b/include/CNZSL/Config.h @@ -3,6 +3,7 @@ Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) 2024 REMqb (remqb at remqb dot fr) + 2025 kbz_8 (contact@kbz8.me) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/include/CNZSL/WgslWriter.h b/include/CNZSL/WgslWriter.h new file mode 100644 index 00000000..71242d94 --- /dev/null +++ b/include/CNZSL/WgslWriter.h @@ -0,0 +1,71 @@ +/* + Copyright (C) 2025 kbz_8 ( contact@kbz8.me ) + This file is part of the "Nazara Shading Wgsluage - C Binding" project + For conditions of distribution and use, see copyright notice in Config.hpp +*/ + +#pragma once + +#ifndef CNZSL_WGSLWRITER_H +#define CNZSL_WGSLWRITER_H + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" +{ +#endif + +typedef struct nzslWgslWriter nzslWgslWriter; +typedef struct nzslWgslOutput nzslWgslOutput; +typedef int (*nzslWgslWriterFeatureSupportCallback)(const char*); + +typedef struct +{ + nzslWgslWriterFeatureSupportCallback featuresCallback; +} nzslWgslWriterEnvironment; + +CNZSL_API nzslWgslWriter* nzslWgslWriterCreate(void); +CNZSL_API void nzslWgslWriterDestroy(nzslWgslWriter* writerPtr); + +CNZSL_API nzslWgslOutput* nzslWgslWriterGenerate(nzslWgslWriter* writerPtr, nzslModule* modulePtr, const nzslBackendParameters* backendParametersPtr); + +/** + * Gets the last error message set by the last operation to this writer + * + * @param writerPtr + * @returns null-terminated error string + */ +CNZSL_API const char* nzslWgslWriterGetLastError(const nzslWgslWriter* writerPtr); + +CNZSL_API void nzslWgslWriterSetEnv(nzslWgslWriter* writerPtr, const nzslWgslWriterEnvironment* env); + +CNZSL_API void nzslWgslOutputDestroy(nzslWgslOutput* outputPtr); +CNZSL_API const char* nzslWgslOutputGetCode(const nzslWgslOutput* outputPtr, size_t* length); + +/** + * As Wgsl does not support combined image samplers, those have to be + * splitted into texture and samplers, shifting the bindings in a given set. + * + * Returns the new binding assigned to a NZSL binding in a given set. + * + * @param output + * @param bindingName + * @return new binding or 0 if not found + */ +CNZSL_API unsigned int nzslWgslOutputGetBindingRemap(const nzslWgslOutput* outputPtr, unsigned int set, unsigned int binding); + +CNZSL_API int nzslWgslOutputGetUsesDrawParameterBaseInstanceUniform(const nzslWgslOutput* outputPtr); +CNZSL_API int nzslWgslOutputGetUsesDrawParameterBaseVertexUniform(const nzslWgslOutput* outputPtr); +CNZSL_API int nzslWgslOutputGetUsesDrawParameterDrawIndexUniform(const nzslWgslOutput* outputPtr); + +#ifdef __cplusplus +} +#endif + +#endif /* CNZSL_WGSLWRITER_H */ + diff --git a/include/NZSL/Ast/Cloner.hpp b/include/NZSL/Ast/Cloner.hpp index fd70c33c..0e7d9281 100644 --- a/include/NZSL/Ast/Cloner.hpp +++ b/include/NZSL/Ast/Cloner.hpp @@ -32,6 +32,8 @@ namespace nzsl::Ast ModulePtr Clone(const Module& module); StatementPtr Clone(const Statement& statement); + StructDescription Clone(const StructDescription& desc); + Cloner& operator=(const Cloner&) = delete; Cloner& operator=(Cloner&&) = delete; @@ -101,6 +103,7 @@ namespace nzsl::Ast inline ExpressionPtr Clone(const Expression& node); inline ModulePtr Clone(const Module& module); inline StatementPtr Clone(const Statement& node); + inline StructDescription Clone(const StructDescription& desc); } #include diff --git a/include/NZSL/Ast/Cloner.inl b/include/NZSL/Ast/Cloner.inl index b79db391..c14a059a 100644 --- a/include/NZSL/Ast/Cloner.inl +++ b/include/NZSL/Ast/Cloner.inl @@ -65,4 +65,10 @@ namespace nzsl::Ast Cloner cloner; return cloner.Clone(node); } + + inline StructDescription Clone(const StructDescription& desc) + { + Cloner cloner; + return cloner.Clone(desc); + } } diff --git a/include/NZSL/Ast/Transformations/Std140EmulationTransformer.hpp b/include/NZSL/Ast/Transformations/Std140EmulationTransformer.hpp new file mode 100644 index 00000000..25e0c420 --- /dev/null +++ b/include/NZSL/Ast/Transformations/Std140EmulationTransformer.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_STD140EMULATION_HPP +#define NZSL_AST_TRANSFORMATIONS_STD140EMULATION_HPP + +#include +#include + +namespace nzsl::Ast +{ + class NZSL_API Std140EmulationTransformer final : public Transformer + { + public: + struct Options; + + Std140EmulationTransformer() = default; + + inline bool Transform(Module& module, TransformerContext& context, std::string* error = nullptr); + bool Transform(Module& module, TransformerContext& context, const Options& options, std::string* error = nullptr); + + struct Options + { + }; + + private: + using Transformer::Transform; + + ExpressionTransformation Transform(AccessFieldExpression&& accessFieldExpr) override; + ExpressionTransformation Transform(AccessIndexExpression&& accessIndexExpr) override; + StatementTransformation Transform(DeclareStructStatement&& declStruct) override; + + DeclareStructStatementPtr DeclareStride16PrimitiveHelper(PrimitiveType type, std::size_t moduleIndex, SourceLocation sourceLocation); + bool ComputeStructDeclarationPadding(StructDescription& desc, const SourceLocation& sourceLocation) const; + FieldOffsets ComputeStructFieldOffsets(const StructDescription& desc, const SourceLocation& sourceLocation) const; + bool HandleStd140Propagation(MultiStatementPtr& multiStatement, std::size_t structIndex, SourceLocation sourceLocation, bool shouldExport); + + std::unordered_map m_structStd140Map; + std::unordered_map m_stride16Structs; + const Options* m_options; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_STD140EMULATION_HPP + diff --git a/include/NZSL/Ast/Transformations/Std140EmulationTransformer.inl b/include/NZSL/Ast/Transformations/Std140EmulationTransformer.inl new file mode 100644 index 00000000..e69de29b diff --git a/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp b/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp index 201a7006..d73bc2d4 100644 --- a/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp +++ b/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp @@ -24,14 +24,21 @@ namespace nzsl::Ast struct Options { bool removeScalarSwizzling = false; + bool removeSwizzleAssigment = false; }; private: using Transformer::Transform; ExpressionTransformation Transform(SwizzleExpression&& swizzle) override; + ExpressionTransformation Transform(AssignExpression&& assign) override; + void PushAssignment(AssignExpression* assign) noexcept; + void PopAssignment() noexcept; + + std::vector m_assignmentStack; const Options* m_options; + bool m_inAssignmentLhs = false; }; } diff --git a/include/NZSL/Ast/Transformations/UniformStructToStd140.hpp b/include/NZSL/Ast/Transformations/UniformStructToStd140.hpp new file mode 100644 index 00000000..6d68a048 --- /dev/null +++ b/include/NZSL/Ast/Transformations/UniformStructToStd140.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_UNIFORMSTRUCTTOSTD140_HPP +#define NZSL_AST_TRANSFORMATIONS_UNIFORMSTRUCTTOSTD140_HPP + +#include + +namespace nzsl::Ast +{ + class NZSL_API UniformStructToStd140Transformer final : public Transformer + { + public: + struct Options; + + UniformStructToStd140Transformer() = default; + + inline bool Transform(Module& module, TransformerContext& context, std::string* error = nullptr); + bool Transform(Module& module, TransformerContext& context, const Options& options, std::string* error = nullptr); + + struct Options + { + bool cloneStructIfUsedElsewhere = true; + }; + + private: + using Transformer::Transform; + + StatementTransformation Transform(DeclareExternalStatement&& declExternal) override; + StatementTransformation Transform(DeclareStructStatement&& declStruct) override; + + const Options* m_options; + std::unordered_map m_structRemap; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_UNIFORMSTRUCTTOSTD140_HPP diff --git a/include/NZSL/Ast/Transformations/UniformStructToStd140.inl b/include/NZSL/Ast/Transformations/UniformStructToStd140.inl new file mode 100644 index 00000000..6de37a84 --- /dev/null +++ b/include/NZSL/Ast/Transformations/UniformStructToStd140.inl @@ -0,0 +1,11 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ + inline bool UniformStructToStd140Transformer::Transform(Module& module, TransformerContext& context, std::string* error) + { + return Transform(module, context, {}, error); + } +} diff --git a/include/NZSL/WgslWriter.hpp b/include/NZSL/WgslWriter.hpp new file mode 100644 index 00000000..c91d79f5 --- /dev/null +++ b/include/NZSL/WgslWriter.hpp @@ -0,0 +1,203 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_WGSLWRITER_HPP +#define NZSL_WGSLWRITER_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace nzsl +{ + class NZSL_API WgslWriter : Ast::ExpressionVisitorExcept, Ast::StatementVisitorExcept + { + public: + using FeaturesSupportCallback = std::function; + struct Environment; + struct Output; + + inline WgslWriter(); + WgslWriter(const WgslWriter&) = delete; + WgslWriter(WgslWriter&&) = delete; + ~WgslWriter() = default; + + Output Generate(Ast::Module& module, const BackendParameters& parameters = {}); + + void SetEnv(Environment environment); + + struct Environment + { + FeaturesSupportCallback featuresCallback; + }; + + struct Output + { + std::string code; + std::unordered_map bindingRemap; + bool usesDrawParameterBaseInstanceUniform; + bool usesDrawParameterBaseVertexUniform; + bool usesDrawParameterDrawIndexUniform; + }; + + static void RegisterPasses(Ast::TransformerExecutor& executor); + + private: + struct PreVisitor; + friend PreVisitor; + + enum class IntrinsicHelper + { + Infinity, + MatrixInverse, + NaN, + }; + + // Attributes + struct AutoBindingAttribute; + struct AuthorAttribute; + struct BindingAttribute; + struct BuiltinAttribute; + struct CondAttribute; + struct DepthWriteAttribute; + struct DescriptionAttribute; + struct EarlyFragmentTestsAttribute; + struct EntryAttribute; + struct FeatureAttribute; + struct InterpAttribute; + struct LicenseAttribute; + struct LocationAttribute; + struct SetAttribute; + struct TagAttribute; + struct UnrollAttribute; + struct WorkgroupAttribute; + + void Append(const Ast::AliasType& type); + void Append(const Ast::ArrayType& type); + void Append(const Ast::DynArrayType& type); + void Append(const Ast::ExpressionType& type); + void Append(const Ast::ExpressionValue& type); + void Append(const Ast::FunctionType& functionType); + void Append(const Ast::ImplicitArrayType& type); + void Append(const Ast::ImplicitMatrixType& type); + void Append(const Ast::ImplicitVectorType& type); + void Append(const Ast::IntrinsicFunctionType& intrinsicFunctionType); + void Append(const Ast::MatrixType& matrixType); + void Append(const Ast::MethodType& methodType); + void Append(const Ast::ModuleType& moduleType); + void Append(const Ast::NamedExternalBlockType& namedExternalBlockType); + void Append(Ast::NoType); + void Append(Ast::PrimitiveType type); + void Append(const Ast::PushConstantType& pushConstantType); + void Append(const Ast::SamplerType& samplerType); + void Append(const Ast::StorageType& storageType); + void Append(const Ast::StructType& structType); + void Append(const Ast::TextureType& samplerType); + void Append(const Ast::Type& type); + void Append(const Ast::UniformType& uniformType); + void Append(const Ast::VectorType& vecType); + template void Append(const T& param); + template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); + template void AppendAttributes(bool appendLine, Args&&... params); + template void AppendAttributesInternal(bool& first, const T& param); + template void AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params); + void AppendAttribute(bool first, AutoBindingAttribute attribute); + void AppendAttribute(bool first, AuthorAttribute attribute); + void AppendAttribute(bool first, BindingAttribute attribute); + void AppendAttribute(bool first, BuiltinAttribute attribute); + void AppendAttribute(bool first, CondAttribute attribute); + void AppendAttribute(bool first, DepthWriteAttribute attribute); + void AppendAttribute(bool first, DescriptionAttribute attribute); + void AppendAttribute(bool first, EarlyFragmentTestsAttribute attribute); + void AppendAttribute(bool first, EntryAttribute attribute); + void AppendAttribute(bool first, FeatureAttribute attribute); + void AppendAttribute(bool first, InterpAttribute attribute); + void AppendAttribute(bool first, LicenseAttribute attribute); + void AppendAttribute(bool first, LocationAttribute attribute); + void AppendAttribute(bool first, SetAttribute attribute); + void AppendAttribute(bool first, TagAttribute attribute); + void AppendAttribute(bool first, UnrollAttribute attribute); + void AppendAttribute(bool first, WorkgroupAttribute attribute); + void AppendComment(std::string_view section); + void AppendCommentSection(std::string_view section); + void AppendIntrinsicHelpers(IntrinsicHelper helper, const Ast::ExpressionType& type); + void AppendHeader(const Ast::Module::Metadata& metadata); + template void AppendIdentifier(const T& map, std::size_t id, bool append_module_prefix = false); + void AppendLine(std::string_view txt = {}); + template void AppendLine(Args&&... params); + void AppendModuleAttributes(const Ast::Module::Metadata& metadata); + void AppendStatementList(std::vector& statements); + template void AppendValue(const T& value); + + void EnterScope(); + void LeaveScope(bool skipLine = true); + + void RegisterAlias(std::size_t aliasIndex, std::string aliasName); + void RegisterConstant(std::size_t constantIndex, std::string constantName); + void RegisterFunction(std::size_t funcIndex, std::string functionName); + void RegisterModule(std::size_t moduleIndex, std::string moduleName); + void RegisterStruct(std::size_t structIndex, const Ast::StructDescription& structDescription); + void RegisterVariable(std::size_t varIndex, std::string varName, bool isInout = false); + + void ScopeVisit(Ast::Statement& node); + + void Visit(Ast::ExpressionPtr& expr, bool encloseIfRequired = false); + + using ExpressionVisitorExcept::Visit; + void Visit(Ast::AccessFieldExpression& node) override; + void Visit(Ast::AccessIdentifierExpression& node) override; + void Visit(Ast::AccessIndexExpression& node) override; + void Visit(Ast::IdentifierValueExpression& node) override; + void Visit(Ast::AssignExpression& node) override; + void Visit(Ast::BinaryExpression& node) override; + void Visit(Ast::CallFunctionExpression& node) override; + void Visit(Ast::CastExpression& node) override; + void Visit(Ast::ConditionalExpression& node) override; + void Visit(Ast::ConstantArrayValueExpression& node) override; + void Visit(Ast::ConstantValueExpression& node) override; + void Visit(Ast::IdentifierExpression& node) override; + void Visit(Ast::IntrinsicExpression& node) override; + void Visit(Ast::SwizzleExpression& node) override; + void Visit(Ast::TypeConstantExpression& node) override; + void Visit(Ast::UnaryExpression& node) override; + + using StatementVisitorExcept::Visit; + void Visit(Ast::BranchStatement& node) override; + void Visit(Ast::BreakStatement& node) override; + void Visit(Ast::ConditionalStatement& node) override; + void Visit(Ast::ContinueStatement& node) override; + void Visit(Ast::DeclareAliasStatement& node) override; + void Visit(Ast::DeclareConstStatement& node) override; + void Visit(Ast::DeclareExternalStatement& node) override; + void Visit(Ast::DeclareFunctionStatement& node) override; + void Visit(Ast::DeclareOptionStatement& node) override; + void Visit(Ast::DeclareStructStatement& node) override; + void Visit(Ast::DeclareVariableStatement& node) override; + void Visit(Ast::DiscardStatement& node) override; + void Visit(Ast::ExpressionStatement& node) override; + void Visit(Ast::ForStatement& node) override; + void Visit(Ast::ForEachStatement& node) override; + void Visit(Ast::ImportStatement& node) override; + void Visit(Ast::MultiStatement& node) override; + void Visit(Ast::NoOpStatement& node) override; + void Visit(Ast::ReturnStatement& node) override; + void Visit(Ast::ScopedStatement& node) override; + void Visit(Ast::WhileStatement& node) override; + + struct State; + + Environment m_environment; + State* m_currentState; + }; +} + +#include + +#endif // NZSL_LANGWRITER_HPP diff --git a/include/NZSL/WgslWriter.inl b/include/NZSL/WgslWriter.inl new file mode 100644 index 00000000..517ced86 --- /dev/null +++ b/include/NZSL/WgslWriter.inl @@ -0,0 +1,13 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include + +namespace nzsl +{ + inline WgslWriter::WgslWriter() : + m_currentState(nullptr) + { + } +} diff --git a/src/CNZSL/Structs/WgslOutput.hpp b/src/CNZSL/Structs/WgslOutput.hpp new file mode 100644 index 00000000..bf8eecb3 --- /dev/null +++ b/src/CNZSL/Structs/WgslOutput.hpp @@ -0,0 +1,16 @@ +// Copyright (C) 2025 2025 kbz_8 ( contact@kbz8.me ) +// This file is part of the "Nazara Shading Language - C Binding" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef CNZSL_STRUCTS_WGSLOUTPUT_HPP +#define CNZSL_STRUCTS_WGSLOUTPUT_HPP + +#include + +struct nzslWgslOutput : nzsl::WgslWriter::Output +{ +}; + +#endif // CNZSL_STRUCTS_WGSLOUTPUT_HPP diff --git a/src/CNZSL/Structs/WgslWriter.hpp b/src/CNZSL/Structs/WgslWriter.hpp new file mode 100644 index 00000000..d4344cdc --- /dev/null +++ b/src/CNZSL/Structs/WgslWriter.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2025 2025 kbz_8 ( contact@kbz8.me ) +// This file is part of the "Nazara Shading Language - C Binding" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef CNZSL_STRUCTS_WGSLWRITER_HPP +#define CNZSL_STRUCTS_WGSLWRITER_HPP + +#include +#include +#include + +struct nzslWgslWriter +{ + std::string lastError; + nzsl::WgslWriter writer; +}; + +#endif // CNZSL_STRUCTS_WGSLWRITER_HPP diff --git a/src/CNZSL/WgslWriter.cpp b/src/CNZSL/WgslWriter.cpp new file mode 100644 index 00000000..a3bd258d --- /dev/null +++ b/src/CNZSL/WgslWriter.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2025 kbz_8 ( contact@kbz8.me ) +// This file is part of the "Nazara Shading Wgsluage - C Binding" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" +{ + CNZSL_API nzslWgslWriter* nzslWgslWriterCreate(void) + { + return new nzslWgslWriter; + } + + CNZSL_API void nzslWgslWriterDestroy(nzslWgslWriter* writerPtr) + { + delete writerPtr; + } + + CNZSL_API nzslWgslOutput* nzslWgslWriterGenerate(nzslWgslWriter* writerPtr, nzslModule* modulePtr, const nzslBackendParameters* backendParametersPtr) + { + try + { + nzsl::BackendParameters parameters; + if (backendParametersPtr) + parameters = static_cast(*backendParametersPtr); + + std::unique_ptr output = std::make_unique(); + static_cast(*output) = writerPtr->writer.Generate(*modulePtr->module, parameters); + + return output.release(); + } + catch (std::exception& e) + { + writerPtr->lastError = fmt::format("nzslWgslWriterGenerate failed: {}", e.what()); + return nullptr; + } + catch (...) + { + writerPtr->lastError = "nzslWgslWriterGenerate failed with unknown error"; + return nullptr; + } + } + + CNZSL_API const char* nzslWgslWriterGetLastError(const nzslWgslWriter* writerPtr) + { + return writerPtr->lastError.c_str(); + } + + CNZSL_API void nzslWgslWriterSetEnv(nzslWgslWriter* writerPtr, const nzslWgslWriterEnvironment* env) + { + nzsl::WgslWriter::Environment writerEnv; + writerEnv.featuresCallback = [=](std::string_view feature) -> bool + { + return env->featuresCallback(feature.data()); + }; + + writerPtr->writer.SetEnv(writerEnv); + } + + CNZSL_API void nzslWgslOutputDestroy(nzslWgslOutput* outputPtr) + { + delete outputPtr; + } + + CNZSL_API const char* nzslWgslOutputGetCode(const nzslWgslOutput* outputPtr, size_t* length) + { + if (length) + *length = outputPtr->code.size(); + + return outputPtr->code.data(); + } + + CNZSL_API unsigned int nzslWgslOutputGetBindingRemap(const nzslWgslOutput* outputPtr, unsigned int set, unsigned int binding) + { + auto it = outputPtr->bindingRemap.find((static_cast(set) << 32 | binding)); + return (it == outputPtr->bindingRemap.end() ? 0 : it->second); + } + + CNZSL_API int nzslWgslOutputGetUsesDrawParameterBaseInstanceUniform(const nzslWgslOutput* outputPtr) + { + return outputPtr->usesDrawParameterBaseInstanceUniform; + } + + CNZSL_API int nzslWgslOutputGetUsesDrawParameterBaseVertexUniform(const nzslWgslOutput* outputPtr) + { + return outputPtr->usesDrawParameterBaseVertexUniform; + } + + CNZSL_API int nzslWgslOutputGetUsesDrawParameterDrawIndexUniform(const nzslWgslOutput* outputPtr) + { + return outputPtr->usesDrawParameterDrawIndexUniform; + } +} + diff --git a/src/NZSL/Ast/Cloner.cpp b/src/NZSL/Ast/Cloner.cpp index 08007fce..ec73c2c5 100644 --- a/src/NZSL/Ast/Cloner.cpp +++ b/src/NZSL/Ast/Cloner.cpp @@ -50,6 +50,30 @@ namespace nzsl::Ast return PopStatement(); } + StructDescription Cloner::Clone(const StructDescription& desc) + { + StructDescription clone; + clone.layout = Clone(desc.layout); + clone.name = desc.name; + clone.tag = desc.tag; + + clone.members.reserve(desc.members.size()); + for (const auto& member : desc.members) + { + auto& cloneMember = clone.members.emplace_back(); + cloneMember.name = member.name; + cloneMember.type = Clone(member.type); + cloneMember.builtin = Clone(member.builtin); + cloneMember.cond = Clone(member.cond); + cloneMember.interp = Clone(member.interp); + cloneMember.locationIndex = Clone(member.locationIndex); + + cloneMember.sourceLocation = member.sourceLocation; + cloneMember.tag = member.tag; + } + return clone; + } + ExpressionValue Cloner::CloneType(const ExpressionValue& exprType) { if (!exprType.HasValue()) @@ -215,26 +239,7 @@ namespace nzsl::Ast auto clone = std::make_unique(); clone->isExported = Clone(node.isExported); clone->structIndex = node.structIndex; - - clone->description.layout = Clone(node.description.layout); - clone->description.name = node.description.name; - clone->description.tag = node.description.tag; - - clone->description.members.reserve(node.description.members.size()); - for (const auto& member : node.description.members) - { - auto& cloneMember = clone->description.members.emplace_back(); - cloneMember.name = member.name; - cloneMember.type = Clone(member.type); - cloneMember.builtin = Clone(member.builtin); - cloneMember.cond = Clone(member.cond); - cloneMember.interp = Clone(member.interp); - cloneMember.locationIndex = Clone(member.locationIndex); - - cloneMember.sourceLocation = member.sourceLocation; - cloneMember.tag = member.tag; - } - + clone->description = Clone(node.description); clone->sourceLocation = node.sourceLocation; return clone; diff --git a/src/NZSL/Ast/RecursiveVisitor.cpp b/src/NZSL/Ast/RecursiveVisitor.cpp index 8468564d..91ddb5bd 100644 --- a/src/NZSL/Ast/RecursiveVisitor.cpp +++ b/src/NZSL/Ast/RecursiveVisitor.cpp @@ -37,10 +37,10 @@ namespace nzsl::Ast void RecursiveVisitor::Visit(CallFunctionExpression& node) { + node.targetFunction->Visit(*this); + for (auto& param : node.parameters) param.expr->Visit(*this); - - node.targetFunction->Visit(*this); } void RecursiveVisitor::Visit(CallMethodExpression& node) @@ -59,6 +59,7 @@ namespace nzsl::Ast void RecursiveVisitor::Visit(ConditionalExpression& node) { + node.condition->Visit(*this); node.truePath->Visit(*this); node.falsePath->Visit(*this); } @@ -124,6 +125,7 @@ namespace nzsl::Ast void RecursiveVisitor::Visit(ConditionalStatement& node) { + node.condition->Visit(*this); node.statement->Visit(*this); } diff --git a/src/NZSL/Ast/Transformations/Std140EmulationTransformer.cpp b/src/NZSL/Ast/Transformations/Std140EmulationTransformer.cpp new file mode 100644 index 00000000..01752174 --- /dev/null +++ b/src/NZSL/Ast/Transformations/Std140EmulationTransformer.cpp @@ -0,0 +1,354 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nzsl::Ast +{ + constexpr std::string_view s_paddingBaseName = "_padding"; + const auto s_primitiveTypeToStructFieldType = frozen::make_unordered_map({ + { PrimitiveType::Float32, StructFieldType::Float1 }, + { PrimitiveType::Float64, StructFieldType::Float2 }, + { PrimitiveType::Int32, StructFieldType::Int1 }, + { PrimitiveType::UInt32, StructFieldType::UInt1 }, + { PrimitiveType::Boolean, StructFieldType::Bool1 }, + }); + + std::size_t DeepResolveStructIndex(const ExpressionType& exprType) + { + std::size_t structIndex; + ExpressionType resolvedExprType = ResolveAlias(exprType); + if (IsArrayType(resolvedExprType)) + structIndex = ResolveStructIndex(std::get(resolvedExprType).InnerType()); + else + structIndex = ResolveStructIndex(resolvedExprType); + return structIndex; + } + + bool Std140EmulationTransformer::Transform(Module& module, TransformerContext& context, const Options& options, std::string* error) + { + m_options = &options; + if (!TransformImportedModules(module, context, error)) + return false; + + return TransformModule(module, context, error); + } + + auto Std140EmulationTransformer::Transform(AccessFieldExpression&& accessFieldExpr) -> ExpressionTransformation + { + const ExpressionType* exprType = GetExpressionType(*accessFieldExpr.expr); + if (!exprType) + return DontVisitChildren{}; + ExpressionType resolvedExprType = ResolveAlias(*exprType); + std::size_t structIndex = DeepResolveStructIndex(resolvedExprType); + if (structIndex == std::numeric_limits::max()) + return DontVisitChildren{}; + const auto& structData = m_context->structs.Retrieve(structIndex, accessFieldExpr.sourceLocation); + + std::uint32_t remainingIndex = accessFieldExpr.fieldIndex; + for (auto it = structData.description->members.begin(); it != structData.description->members.end(); ++it) + { + if (it->cond.HasValue() && !it->cond.GetResultingValue()) + continue; + + if (remainingIndex == 0) + { + while (it != structData.description->members.end() && it->name.compare(0, s_paddingBaseName.length(), s_paddingBaseName) == 0) + { + accessFieldExpr.fieldIndex++; + ++it; + } + break; + } + + remainingIndex--; + } + + return DontVisitChildren{}; + } + + auto Std140EmulationTransformer::Transform(AccessIndexExpression&& accessIndexExpr) -> ExpressionTransformation + { + assert(accessIndexExpr.expr); + if (accessIndexExpr.expr->GetType() != NodeType::AccessFieldExpression) + return DontVisitChildren{}; + AccessFieldExpression& accessFieldExpr = *static_cast(accessIndexExpr.expr.get()); + const ExpressionType* exprType = GetExpressionType(*accessFieldExpr.expr); + if (!exprType) + return DontVisitChildren{}; + ExpressionType resolvedExprType = ResolveAlias(*exprType); + + StructDescription& desc = *m_context->structs.Retrieve(DeepResolveStructIndex(resolvedExprType), accessFieldExpr.sourceLocation).description; + if (!desc.layout.HasValue() || desc.layout.GetResultingValue() != MemoryLayout::Std140) + return DontVisitChildren{}; + + ExpressionType fieldExprType; + std::uint32_t remainingIndices = accessFieldExpr.fieldIndex; + for (const auto& member : desc.members) + { + if (member.cond.HasValue()) + { + if (!member.cond.IsResultingValue()) + return DontVisitChildren{}; //< unresolved + + if (!member.cond.GetResultingValue()) + continue; + } + if (remainingIndices == 0) + { + fieldExprType = ResolveAlias(member.type.GetResultingValue()); + break; + } + remainingIndices--; + } + + assert(IsArrayType(fieldExprType)); + auto& arrayField = std::get(fieldExprType); + if (!IsStructType(arrayField.InnerType())) + return DontVisitChildren{}; + auto& innerArrayField = std::get(arrayField.InnerType()); + auto it = std::find_if(m_stride16Structs.begin(), m_stride16Structs.end(), [&](const auto& elem){ return elem.second == innerArrayField.structIndex; }); + if (it == m_stride16Structs.end()) + return DontVisitChildren{}; + + static_cast(GetCurrentExpressionPtr().get())->cachedExpressionType = ExpressionType{ StructType{ it->second } }; + + auto finalAccessFieldExpr = std::make_unique(); + finalAccessFieldExpr->sourceLocation = accessIndexExpr.sourceLocation; + finalAccessFieldExpr->expr = std::move(GetCurrentExpressionPtr()); + finalAccessFieldExpr->fieldIndex = 0; // In stride helpers the value should always be the first field, followed by padding fields + finalAccessFieldExpr->cachedExpressionType = ExpressionType{ StructType{ it->second } }; + + return ReplaceExpression{ std::move(finalAccessFieldExpr) }; + } + + auto Std140EmulationTransformer::Transform(DeclareStructStatement&& declStruct) -> StatementTransformation + { + auto& structData = m_context->structs.Retrieve(*declStruct.structIndex, declStruct.sourceLocation); + StructDescription* desc = structData.description; + + bool shouldReplaceStatement = false; + MultiStatementPtr multiStatement = ShaderBuilder::MultiStatement(); + multiStatement->sourceLocation = declStruct.sourceLocation; + + if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != MemoryLayout::Std140) + { + if (!HandleStd140Propagation(multiStatement, *declStruct.structIndex, declStruct.sourceLocation, (declStruct.isExported.HasValue() ? declStruct.isExported.GetResultingValue() : false))) + return DontVisitChildren{}; + shouldReplaceStatement = m_structStd140Map.count(*declStruct.structIndex); + } + + for (auto& field : desc->members) + { + const ExpressionType& resolvedFieldType = ResolveAlias(field.type.GetResultingValue()); + auto handleStruct = [&](const StructType& structure) + { + if (m_structStd140Map.count(structure.structIndex)) + { + field.type = ExpressionType{ StructType{ m_structStd140Map.at(structure.structIndex) } }; + shouldReplaceStatement = true; + } + }; + + if (IsArrayType(resolvedFieldType)) + { + const auto& array = std::get(resolvedFieldType); + if (IsPrimitiveType(array.containedType->type)) + { + auto primitiveType = std::get(array.containedType->type); + if (!m_stride16Structs.count(primitiveType)) + multiStatement->statements.emplace_back(DeclareStride16PrimitiveHelper(primitiveType, structData.moduleIndex, declStruct.sourceLocation)); + array.containedType->type = ExpressionType{ StructType{ m_stride16Structs[primitiveType] } }; + shouldReplaceStatement = true; + } + else if (IsStructType(array.containedType->type)) + handleStruct(std::get(array.containedType->type)); + } + else if (IsStructType(resolvedFieldType)) + handleStruct(std::get(resolvedFieldType)); + } + + if (desc->layout.HasValue() && desc->layout.GetResultingValue() == MemoryLayout::Std140) + ComputeStructDeclarationPadding(*desc, declStruct.sourceLocation); + + if (shouldReplaceStatement) + { + multiStatement->statements.emplace_back(std::move(GetCurrentStatementPtr())); + return ReplaceStatement{ std::move(multiStatement) }; + } + return DontVisitChildren{}; + } + + DeclareStructStatementPtr Std140EmulationTransformer::DeclareStride16PrimitiveHelper(PrimitiveType type, std::size_t moduleIndex, SourceLocation sourceLocation) + { + FieldOffsets fieldOffset(nzsl::StructLayout::Packed); + fieldOffset.AddField(s_primitiveTypeToStructFieldType.at(type)); + + StructDescription::StructMember member; + member.type = ExpressionValue{ ExpressionType{ type } }; + member.sourceLocation = sourceLocation; + member.name = "value"; + + StructDescription desc; + desc.name = fmt::format("{}_stride16", ToString(type, sourceLocation)); + desc.members.push_back(std::move(member)); + ComputeStructDeclarationPadding(desc, sourceLocation); + + auto structStatement = ShaderBuilder::DeclareStruct(std::move(desc), ExpressionValue{ false }); + structStatement->sourceLocation = sourceLocation; + + TransformerContext::StructData structData; + structData.description = &structStatement->description; + structData.moduleIndex = moduleIndex; + structStatement->structIndex = m_context->structs.Register(structData, std::nullopt, sourceLocation); + + m_stride16Structs[type] = *structStatement->structIndex; + return structStatement; + } + + bool Std140EmulationTransformer::ComputeStructDeclarationPadding(StructDescription& desc, const SourceLocation& sourceLocation) const + { + bool descriptionChanged = false; + std::size_t paddingFieldIndex = 0; + FieldOffsets fieldOffsets(StructLayout::Packed); + + auto appendPaddingField = [&](std::size_t fieldIndex) + { + StructDescription::StructMember member; + member.type = ExpressionType{ PrimitiveType::Float32 }; + member.sourceLocation = sourceLocation; + member.name = fmt::format("{}{}", s_paddingBaseName, paddingFieldIndex); + desc.members.insert(desc.members.begin() + fieldIndex, std::move(member)); + fieldOffsets.AddField(s_primitiveTypeToStructFieldType.at(PrimitiveType::Float32)); + paddingFieldIndex++; + }; + + auto fillWithPaddingFieldsUntilAlignedSize = [&](std::size_t fieldIndex, std::size_t sizeGoal) -> std::size_t + { + std::size_t i = 0; + for (; fieldOffsets.GetSize() < sizeGoal; ++i) + { + appendPaddingField(fieldIndex + i); + descriptionChanged = true; + } + return i; + }; + + // Field that have struct type or array of struct must be aligned on 16 bytes + // This loop adds padding elements before those fields if necessary + for (std::size_t i = 0; i < desc.members.size(); ++i) + { + ExpressionType resolvedFieldType = ResolveAlias(desc.members.at(i).type.GetResultingValue()); + + std::size_t structIndex = DeepResolveStructIndex(resolvedFieldType); + if (structIndex != std::numeric_limits::max()) + { + if (fieldOffsets.GetSize() % 16 != 0) + i += fillWithPaddingFieldsUntilAlignedSize(i, Nz::Align(static_cast(fieldOffsets.GetSize()), 16)); + + FieldOffsets innerFieldOffsets = ComputeStructFieldOffsets(*m_context->structs.Retrieve(structIndex, sourceLocation).description, sourceLocation); + if (IsArrayType(resolvedFieldType)) + fieldOffsets.AddStructArray(innerFieldOffsets, std::get(resolvedFieldType).length); + else + fieldOffsets.AddStruct(innerFieldOffsets); + } + else + { + FieldOffsets alignedFieldOffsets(StructLayout::Std140); + RegisterStructField(alignedFieldOffsets, resolvedFieldType); + fieldOffsets.AddStruct(alignedFieldOffsets); + } + } + + fieldOffsets = ComputeStructFieldOffsets(desc, sourceLocation); + fillWithPaddingFieldsUntilAlignedSize(desc.members.size(), Nz::Align(static_cast(fieldOffsets.GetAlignedSize()), 16)); + return descriptionChanged; + } + + FieldOffsets Std140EmulationTransformer::ComputeStructFieldOffsets(const StructDescription& desc, const SourceLocation& location) const + { + FieldOffsets innerFieldOffset(StructLayout::Packed); + auto structFinder = [&](std::size_t structIndex) -> const nzsl::FieldOffsets& + { + StructDescription* innerDesc = m_context->structs.Retrieve(structIndex, location).description; + innerFieldOffset = ComputeStructFieldOffsets(*innerDesc, location); + return innerFieldOffset; + }; + + FieldOffsets fieldOffset(StructLayout::Packed); + for (auto& field : desc.members) + { + const auto& resolvedFieldType = ResolveAlias(field.type.GetResultingValue()); + RegisterStructField(fieldOffset, resolvedFieldType, structFinder); + } + + return fieldOffset; + } + + bool Std140EmulationTransformer::HandleStd140Propagation(MultiStatementPtr& multiStatement, std::size_t structIndex, SourceLocation sourceLocation, bool shouldExport) + { + bool isUsedInStd140Struct = false; + bool isUsedInPlainCode = false; + + const auto& variables = m_context->variables; + for (const auto& [_, var] : variables.values) + { + std::size_t varStructIndex = DeepResolveStructIndex(var.type); + if (varStructIndex == std::numeric_limits::max()) + continue; + if (varStructIndex == structIndex) + isUsedInPlainCode = true; + StructDescription& varStructDesc = *m_context->structs.Retrieve(varStructIndex, sourceLocation).description; + for (const auto& member : varStructDesc.members) + { + std::size_t memberStructIndex = DeepResolveStructIndex(member.type.GetResultingValue()); + if (memberStructIndex == structIndex) + { + if (varStructDesc.layout.HasValue() && varStructDesc.layout.GetResultingValue() == MemoryLayout::Std140) + isUsedInStd140Struct = true; + else + isUsedInPlainCode = true; + } + } + + if (isUsedInStd140Struct && isUsedInPlainCode) // Skip useless iterations + break; + } + if (isUsedInStd140Struct) + { + auto& structData = m_context->structs.Retrieve(structIndex, sourceLocation); + if (isUsedInPlainCode) + { + // Cloning struct but with Std140 layout + StructDescription desc = Clone(*structData.description); + desc.layout = ExpressionValue{ MemoryLayout::Std140 }; + desc.name += "_std140"; + ComputeStructDeclarationPadding(desc, sourceLocation); + + auto newStruct = ShaderBuilder::DeclareStruct(std::move(desc), ExpressionValue{ shouldExport }); + newStruct->sourceLocation = sourceLocation; + + TransformerContext::StructData newStructData; + newStructData.description = &newStruct->description; + newStructData.moduleIndex = structData.moduleIndex; + + newStruct->structIndex = m_context->structs.Register(newStructData, std::nullopt, sourceLocation); + m_structStd140Map[structIndex] = *newStruct->structIndex; + + multiStatement->statements.emplace_back(std::move(newStruct)); + } + else + structData.description->layout = ExpressionValue{ MemoryLayout::Std140 }; + return true; + } + return false; + } +} diff --git a/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp b/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp index b042a2fd..d50f91e3 100644 --- a/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp +++ b/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp @@ -8,6 +8,18 @@ namespace nzsl::Ast { + void SwizzleTransformer::PushAssignment(AssignExpression* assign) noexcept + { + m_assignmentStack.push_back(assign); + m_inAssignmentLhs = true; + } + + void SwizzleTransformer::PopAssignment() noexcept + { + m_inAssignmentLhs = false; + m_assignmentStack.pop_back(); + } + bool SwizzleTransformer::Transform(Module& module, TransformerContext& context, const Options& options, std::string* error) { m_options = &options; @@ -55,6 +67,155 @@ namespace nzsl::Ast return ReplaceExpression{ std::move(cast) }; } + if (m_options->removeSwizzleAssigment && m_inAssignmentLhs) + { + if (!IsVectorType(*exprType)) + return VisitChildren{}; + + AssignExpression* assign = m_assignmentStack.empty() ? nullptr : m_assignmentStack.back(); + if (!assign || !assign->right) + return VisitChildren{}; + + // Flatten swizzle chain + std::array flatComponents{}; + std::size_t flatCount = swizzle.componentCount; + for (std::size_t i = 0; i < flatCount; ++i) + flatComponents[i] = swizzle.components[i]; + + ExpressionPtr baseExpr = std::move(swizzle.expression); // Take ownership as we'll replace the LHS anyway + while (baseExpr->GetType() == Ast::NodeType::SwizzleExpression) + { + SwizzleExpression* innerSwz = static_cast(baseExpr.get()); + std::array nextComponents{}; + for (std::size_t i = 0; i < flatCount; ++i) + nextComponents[i] = innerSwz->components[flatComponents[i]]; + flatComponents = nextComponents; + // Step deeper + baseExpr = std::move(innerSwz->expression); + } + + const ExpressionType* baseVecEt = GetResolvedExpressionType(*baseExpr); + if (!baseVecEt || !IsVectorType(*baseVecEt)) + return VisitChildren{}; + + const VectorType& vecType = std::get(*baseVecEt); + const std::size_t vecSize = vecType.componentCount; + const PrimitiveType baseType = vecType.type; + + // Cache LHS base vector and RHS (reused several times) + ExpressionPtr baseVec = CacheExpression(std::move(baseExpr)); + ExpressionPtr rhs = CacheExpression(std::move(assign->right)); + + // Constructor of full vector: vecN[T](...) + auto ctor = std::make_unique(); + ctor->sourceLocation = swizzle.sourceLocation; + ctor->targetType = ExpressionType{ VectorType{ vecSize, baseType } }; + ctor->cachedExpressionType = ExpressionType{ VectorType{ vecSize, baseType } }; + ctor->expressions.reserve(vecSize); + + // Map destination index to optional RHS component index + auto mapDstToRhsIndex = [&](std::size_t dst) -> std::optional + { + for (std::size_t k = 0; k < flatCount; ++k) + { + if (flatComponents[k] == dst) + return k; + } + return std::nullopt; + }; + + // Small helper to read one component from RHS + auto makeRhsComponentExpr = [&](std::size_t rhsCompIndex) -> ExpressionPtr + { + if (flatCount == 1) + return Clone(*rhs); // Scalar write + auto rhsSwz = std::make_unique(); + rhsSwz->sourceLocation = swizzle.sourceLocation; + rhsSwz->expression = Clone(*rhs); + rhsSwz->componentCount = 1; + rhsSwz->components[0] = static_cast(rhsCompIndex); + rhsSwz->cachedExpressionType = ExpressionType{ baseType }; + return rhsSwz; + }; + + // If written components form a contiguous in-order suffix, + // emit "... , rhs" directly (gives vec4[f32](vec.x, rhs) for yzw) + auto isContiguousSuffix = [&]() -> std::optional + { + // Find first written index + std::size_t minWritten = vecSize; + for (std::size_t k = 0; k < flatCount; ++k) + minWritten = std::min(minWritten, flatComponents[k]); + if (minWritten + flatCount != vecSize) + return std::nullopt; // Not a suffix length + // Check order: {min,...,vec_size-1} + for (std::size_t k = 0; k < flatCount; ++k) + { + if (flatComponents[k] != minWritten + k) + return std::nullopt; + } + return minWritten; // Suffix starts here + }; + + if (auto suffixStart = isContiguousSuffix()) + { + // Keep prefix from baseVec + for (std::size_t dst = 0; dst < *suffixStart; ++dst) + { + auto keepSwz = std::make_unique(); + keepSwz->sourceLocation = swizzle.sourceLocation; + keepSwz->expression = Clone(*baseVec); + keepSwz->componentCount = 1; + keepSwz->components[0] = static_cast(dst); + keepSwz->cachedExpressionType = ExpressionType{ baseType }; + ctor->expressions.push_back(std::move(keepSwz)); + HandleExpression(ctor->expressions.back()); + } + + // Append rhs as a single argument + ctor->expressions.push_back(Clone(*rhs)); + HandleExpression(ctor->expressions.back()); + } + else + { + // General case: per-component merge + for (std::size_t dst = 0; dst < vecSize; ++dst) + { + if (auto rhsComp = mapDstToRhsIndex(dst)) + { + ExpressionPtr fromRhs = makeRhsComponentExpr(*rhsComp); + HandleExpression(fromRhs); + ctor->expressions.push_back(std::move(fromRhs)); + } + else + { + auto keepSwz = std::make_unique(); + keepSwz->sourceLocation = swizzle.sourceLocation; + keepSwz->expression = Clone(*baseVec); + keepSwz->componentCount = 1; + keepSwz->components[0] = static_cast(dst); + keepSwz->cachedExpressionType = ExpressionType{ baseType }; + ctor->expressions.push_back(std::move(keepSwz)); + HandleExpression(ctor->expressions.back()); + } + } + } + + // vec.xyz = rhs; ==> vec = vecN[T](merged...) + assign->left = std::move(baseVec); + assign->right = std::move(ctor); + } + + return VisitChildren{}; + } + + auto SwizzleTransformer::Transform(AssignExpression&& assign) -> ExpressionTransformation + { + PushAssignment(&assign); + HandleExpression(assign.left); + PopAssignment(); + + HandleExpression(assign.right); return VisitChildren{}; } } diff --git a/src/NZSL/Ast/Transformations/UniformStructToStd140.cpp b/src/NZSL/Ast/Transformations/UniformStructToStd140.cpp new file mode 100644 index 00000000..eb1961da --- /dev/null +++ b/src/NZSL/Ast/Transformations/UniformStructToStd140.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace nzsl::Ast +{ + bool UniformStructToStd140Transformer::Transform(Module& module, TransformerContext& context, const Options& options, std::string* error) + { + m_options = &options; + if (!TransformImportedModules(module, context, error)) + return false; + + return TransformModule(module, context, error); + } + + auto UniformStructToStd140Transformer::Transform(DeclareExternalStatement&& node) -> StatementTransformation + { + for (auto& var : node.externalVars) + { + auto& varType = var.type.GetResultingValue(); + if (IsUniformType(varType)) + { + auto& uniformType = std::get(varType); + if (m_structRemap.count(uniformType.containedType.structIndex)) + uniformType.containedType.structIndex = m_structRemap.at(uniformType.containedType.structIndex); + } + } + return VisitChildren{}; + } + + auto UniformStructToStd140Transformer::Transform(DeclareStructStatement&& declStruct) -> StatementTransformation + { + if (!declStruct.structIndex.has_value()) + return VisitChildren{}; + + if (declStruct.description.layout.HasValue() && declStruct.description.layout.GetResultingValue() == MemoryLayout::Std140) + return DontVisitChildren{}; + + bool isUsedInUniformBuffer = false; + bool isUsedInPlainCode = false; + + const auto& variables = m_context->variables; + for (const auto& [_, var] : variables.values) + { + const auto& resolvedVarType = ResolveAlias(var.type); + if (IsStructType(resolvedVarType) && std::get(resolvedVarType).structIndex == *declStruct.structIndex) + isUsedInPlainCode = true; + else if (IsUniformType(resolvedVarType) && std::get(resolvedVarType).containedType.structIndex == *declStruct.structIndex) + isUsedInUniformBuffer = true; + + if (isUsedInUniformBuffer && isUsedInPlainCode) // Skip useless iterations + break; + } + + if (isUsedInUniformBuffer) + { + if (isUsedInPlainCode && m_options->cloneStructIfUsedElsewhere) + { + // Cloning struct but with Std140 layout + StructDescription desc = Clone(declStruct.description); + desc.layout = ExpressionValue{ MemoryLayout::Std140 }; + desc.name += "_std140"; + + MultiStatementPtr multiStatement = ShaderBuilder::MultiStatement(); + multiStatement->sourceLocation = declStruct.sourceLocation; + + auto newStruct = ShaderBuilder::DeclareStruct(std::move(desc), ExpressionValue{ (declStruct.isExported.HasValue() ? declStruct.isExported.GetResultingValue() : false) }); + newStruct->sourceLocation = declStruct.sourceLocation; + newStruct->structIndex = m_context->structs.RegisterNewIndex(); + + m_structRemap[*declStruct.structIndex] = *newStruct->structIndex; + + multiStatement->statements.emplace_back(std::move(newStruct)); + multiStatement->statements.emplace_back(std::move(GetCurrentStatementPtr())); + + return ReplaceStatement{ std::move(multiStatement) }; + } + else + m_context->structs.Retrieve(*declStruct.structIndex, declStruct.sourceLocation).description->layout = ExpressionValue{ MemoryLayout::Std140 }; + } + return VisitChildren{}; + } +} diff --git a/src/NZSL/GlslWriter.cpp b/src/NZSL/GlslWriter.cpp index eb650ab9..aa9fc6ea 100644 --- a/src/NZSL/GlslWriter.cpp +++ b/src/NZSL/GlslWriter.cpp @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include #include @@ -350,7 +352,6 @@ namespace nzsl }; } - struct GlslWriter::State { State(const BackendParameters& backendParameters, const GlslWriter::Parameters& glslParameters) : @@ -653,9 +654,9 @@ namespace nzsl if (m_currentState->hasDrawParametersBaseInstanceUniform) Append(s_glslWriterShaderDrawParametersBaseInstanceName); else if (!m_environment.glES && glVersion >= 460) - Append("gl_BaseInstance"); + Append("uint(gl_BaseInstance)"); else - Append("gl_BaseInstanceARB"); + Append("uint(gl_BaseInstanceARB)"); break; } @@ -664,9 +665,9 @@ namespace nzsl if (m_currentState->hasDrawParametersBaseVertexUniform) Append(s_glslWriterShaderDrawParametersBaseVertexName); else if (!m_environment.glES && glVersion >= 460) - Append("gl_BaseVertex"); + Append("uint(gl_BaseVertex)"); else - Append("gl_BaseVertexARB"); + Append("uint(gl_BaseVertexARB)"); break; } @@ -675,15 +676,15 @@ namespace nzsl if (m_currentState->hasDrawParametersDrawIndexUniform) Append(s_glslWriterShaderDrawParametersDrawIndexName); else if (!m_environment.glES && glVersion >= 460) - Append("gl_DrawID"); + Append("uint(gl_DrawID)"); else - Append("gl_DrawIDARB"); + Append("uint(gl_DrawIDARB)"); break; } case Ast::BuiltinEntry::InstanceIndex: { - Append("(", Ast::BuiltinEntry::BaseInstance, " + gl_InstanceID)"); + Append(Ast::BuiltinEntry::BaseInstance, " + uint(gl_InstanceID)"); break; } @@ -691,7 +692,16 @@ namespace nzsl { auto it = s_glslBuiltinMapping.find(builtin); assert(it != s_glslBuiltinMapping.end()); - Append(it->second.identifier); + const std::array builtinsToCast { + Ast::BuiltinEntry::LocalInvocationIndex, + Ast::BuiltinEntry::VertexIndex, + Ast::BuiltinEntry::WorkgroupCount, + Ast::BuiltinEntry::WorkgroupIndices, + }; + if (std::find(builtinsToCast.begin(), builtinsToCast.end(), it->first) != builtinsToCast.end()) + Append("uint(", it->second.identifier, ')'); + else + Append(it->second.identifier); } } } @@ -1333,13 +1343,13 @@ namespace nzsl if (m_currentState->hasDrawParametersBaseInstanceUniform || m_currentState->hasDrawParametersBaseVertexUniform || m_currentState->hasDrawParametersDrawIndexUniform) { if (m_currentState->hasDrawParametersBaseInstanceUniform) - AppendLine("uniform int ", s_glslWriterShaderDrawParametersBaseInstanceName, ";"); + AppendLine("uniform uint ", s_glslWriterShaderDrawParametersBaseInstanceName, ";"); if (m_currentState->hasDrawParametersBaseVertexUniform) - AppendLine("uniform int ", s_glslWriterShaderDrawParametersBaseVertexName, ";"); + AppendLine("uniform uint ", s_glslWriterShaderDrawParametersBaseVertexName, ";"); if (m_currentState->hasDrawParametersDrawIndexUniform) - AppendLine("uniform int ", s_glslWriterShaderDrawParametersDrawIndexName, ";"); + AppendLine("uniform uint ", s_glslWriterShaderDrawParametersDrawIndexName, ";"); AppendLine(); } @@ -2781,7 +2791,7 @@ namespace nzsl assert(node.structIndex); RegisterStruct(*node.structIndex, &node.description, structName); - // Don't output structs used for UBO/SSBO description + // Don't output structs only used for UBO/SSBO description if (m_currentState->previsitor.bufferStructs.UnboundedTest(*node.structIndex)) { if (m_currentState->backendParameters.debugLevel >= DebugLevel::Minimal) diff --git a/src/NZSL/Lang/LangData.hpp b/src/NZSL/Lang/LangData.hpp index d62a6c99..8523b0f1 100644 --- a/src/NZSL/Lang/LangData.hpp +++ b/src/NZSL/Lang/LangData.hpp @@ -54,16 +54,16 @@ namespace nzsl::LangData }; constexpr auto s_builtinData = frozen::make_unordered_map({ - { Ast::BuiltinEntry::BaseInstance, { "base_instance", ShaderStageType::Vertex, Ast::PrimitiveType::Int32 } }, - { Ast::BuiltinEntry::BaseVertex, { "base_vertex", ShaderStageType::Vertex, Ast::PrimitiveType::Int32 } }, - { Ast::BuiltinEntry::DrawIndex, { "draw_index", ShaderStageType::Vertex, Ast::PrimitiveType::Int32 } }, + { Ast::BuiltinEntry::BaseInstance, { "base_instance", ShaderStageType::Vertex, Ast::PrimitiveType::UInt32 } }, + { Ast::BuiltinEntry::BaseVertex, { "base_vertex", ShaderStageType::Vertex, Ast::PrimitiveType::UInt32 } }, + { Ast::BuiltinEntry::DrawIndex, { "draw_index", ShaderStageType::Vertex, Ast::PrimitiveType::UInt32 } }, { Ast::BuiltinEntry::FragCoord, { "frag_coord", ShaderStageType::Fragment, Ast::VectorType { 4, Ast::PrimitiveType::Float32 } } }, { Ast::BuiltinEntry::FragDepth, { "frag_depth", ShaderStageType::Fragment, Ast::PrimitiveType::Float32 } }, { Ast::BuiltinEntry::GlocalInvocationIndices, { "global_invocation_indices", ShaderStageType::Compute, Ast::VectorType { 3, Ast::PrimitiveType::UInt32 } } }, - { Ast::BuiltinEntry::InstanceIndex, { "instance_index", ShaderStageType::Vertex, Ast::PrimitiveType::Int32 } }, + { Ast::BuiltinEntry::InstanceIndex, { "instance_index", ShaderStageType::Vertex, Ast::PrimitiveType::UInt32 } }, { Ast::BuiltinEntry::LocalInvocationIndex, { "local_invocation_index", ShaderStageType::Compute, Ast::PrimitiveType::UInt32 } }, { Ast::BuiltinEntry::LocalInvocationIndices, { "local_invocation_indices", ShaderStageType::Compute, Ast::VectorType { 3, Ast::PrimitiveType::UInt32 } } }, - { Ast::BuiltinEntry::VertexIndex, { "vertex_index", ShaderStageType::Vertex, Ast::PrimitiveType::Int32 } }, + { Ast::BuiltinEntry::VertexIndex, { "vertex_index", ShaderStageType::Vertex, Ast::PrimitiveType::UInt32 } }, { Ast::BuiltinEntry::VertexPosition, { "position", ShaderStageType::Vertex, Ast::VectorType { 4, Ast::PrimitiveType::Float32 } } }, { Ast::BuiltinEntry::WorkgroupCount, { "workgroup_count", ShaderStageType::Compute, Ast::VectorType { 3, Ast::PrimitiveType::UInt32 } } }, { Ast::BuiltinEntry::WorkgroupIndices, { "workgroup_indices", ShaderStageType::Compute, Ast::VectorType { 3, Ast::PrimitiveType::UInt32 } } } diff --git a/src/NZSL/LangWriter.cpp b/src/NZSL/LangWriter.cpp index 5a5e0866..b598a397 100644 --- a/src/NZSL/LangWriter.cpp +++ b/src/NZSL/LangWriter.cpp @@ -1790,5 +1790,4 @@ namespace nzsl ScopeVisit(*node.body); } - } diff --git a/src/NZSL/WgslWriter.cpp b/src/NZSL/WgslWriter.cpp new file mode 100644 index 00000000..f632c571 --- /dev/null +++ b/src/NZSL/WgslWriter.cpp @@ -0,0 +1,2628 @@ +// Copyright (C) 2025 kbz_8 (contact@kbz8.me) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nzsl +{ + constexpr std::string_view s_wgslWriterBuiltinEmulationStructName = "_nzslBuiltinEmulation"; + + enum class WgslFeature + { + None = -1, + + // Emulation + EmulateBaseInstance, + EmulateBaseVertex, + EmulateDrawIndex, + + // wgpu native features + WgpuBufferBindingArray, + WgpuConservativeDepth, + WgpuEarlyFragmentTests, + WgpuFloat64, + WgpuPushConstants, + WgpuStorageBindingArray, + WgpuTextureBindingArray, + }; + + struct WgslBuiltin + { + std::string_view identifier; + WgslFeature requiredFeature; + }; + + const auto s_wgslBuiltinMapping = frozen::make_unordered_map({ + { Ast::BuiltinEntry::BaseInstance, { "base_instance", WgslFeature::EmulateBaseInstance } }, + { Ast::BuiltinEntry::BaseVertex, { "base_vertex", WgslFeature::EmulateBaseVertex } }, + { Ast::BuiltinEntry::DrawIndex, { "draw_index", WgslFeature::EmulateDrawIndex } }, + { Ast::BuiltinEntry::FragCoord, { "position", WgslFeature::None } }, + { Ast::BuiltinEntry::FragDepth, { "frag_depth", WgslFeature::None } }, + { Ast::BuiltinEntry::GlocalInvocationIndices, { "global_invocation_id", WgslFeature::None } }, + { Ast::BuiltinEntry::InstanceIndex, { "instance_index", WgslFeature::None } }, + { Ast::BuiltinEntry::LocalInvocationIndex, { "local_invocation_index", WgslFeature::None } }, + { Ast::BuiltinEntry::LocalInvocationIndices, { "local_invocation_id", WgslFeature::None } }, + { Ast::BuiltinEntry::VertexIndex, { "vertex_index", WgslFeature::None } }, + { Ast::BuiltinEntry::VertexPosition, { "position", WgslFeature::None } }, + { Ast::BuiltinEntry::WorkgroupCount, { "num_workgroups", WgslFeature::None } }, + { Ast::BuiltinEntry::WorkgroupIndices, { "workgroup_id", WgslFeature::None } }, + }); + + const std::array s_wgslBuiltinsToEmulate { + Ast::BuiltinEntry::BaseInstance, + Ast::BuiltinEntry::BaseVertex, + Ast::BuiltinEntry::DrawIndex, + }; + + struct WgslWriter::PreVisitor : Ast::RecursiveVisitor + { + PreVisitor(WgslWriter& writer) : m_writer(writer) {} + + void Visit(Ast::DeclareFunctionStatement& node) override + { + if (node.funcIndex) + m_writer.RegisterFunction(*node.funcIndex, node.name); + + if (node.entryStage.HasValue()) + { + ShaderStageType stage = node.entryStage.GetResultingValue(); + + if (stage == ShaderStageType::Fragment) + { + if (node.depthWrite.HasValue() && node.depthWrite.GetResultingValue() != Ast::DepthWriteMode::Replace) + features.insert(WgslFeature::WgpuConservativeDepth); + + if (node.earlyFragmentTests.HasValue() && node.earlyFragmentTests.GetResultingValue()) + features.insert(WgslFeature::WgpuEarlyFragmentTests); + } + + if (!node.parameters.empty()) + { + assert(node.parameters.size() == 1); + auto& parameter = node.parameters.front(); + const auto& parameterType = parameter.type.GetResultingValue(); + + assert(std::holds_alternative(parameterType)); + + std::size_t structIndex = std::get(parameterType).structIndex; + const Ast::StructDescription* structDesc = Nz::Retrieve(structs, structIndex); + + for (const auto& member : structDesc->members) + { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + + if (member.builtin.HasValue()) + { + auto it = s_wgslBuiltinMapping.find(member.builtin.GetResultingValue()); + assert(it != s_wgslBuiltinMapping.end()); + + if (it->second.requiredFeature != WgslFeature::None) + features.insert(it->second.requiredFeature); + } + } + } + } + + RecursiveVisitor::Visit(node); + } + + void Visit(Ast::DeclareExternalStatement& node) override + { + for (const auto& extVar : node.externalVars) + { + const Ast::ExpressionType& type = extVar.type.GetResultingValue(); + if (IsPushConstantType(type)) + features.insert(WgslFeature::WgpuPushConstants); + else if (IsArrayType(type)) + { + const Ast::ArrayType& array = std::get(type); + if (IsStorageType(array.InnerType())) + features.insert(WgslFeature::WgpuStorageBindingArray); + else if (IsTextureType(array.InnerType())) + features.insert(WgslFeature::WgpuTextureBindingArray); + else if (IsStructType(array.InnerType())) + features.insert(WgslFeature::WgpuBufferBindingArray); + } + } + + RecursiveVisitor::Visit(node); + } + + void Visit(Ast::IntrinsicExpression& node) override + { + RecursiveVisitor::Visit(node); + + const Ast::ExpressionType& paramType = ResolveAlias(EnsureExpressionType(*node.parameters[0])); + + if (node.intrinsic == Ast::IntrinsicType::IsInf) + { + assert((IsVectorType(paramType) || IsPrimitiveType(paramType)) && "expected a vector type or a primitive type"); + const Ast::PrimitiveType& type = IsVectorType(paramType) ? std::get(paramType).type : std::get(paramType); + intrinsicHelpers[IntrinsicHelper::Infinity].emplace(type); + } + else if (node.intrinsic == Ast::IntrinsicType::MatrixInverse) + { + assert(IsMatrixType(paramType) && "expected a matrix"); + intrinsicHelpers[IntrinsicHelper::MatrixInverse].emplace(paramType); + } + } + + void Visit(Ast::TypeConstantExpression& node) override + { + assert(IsPrimitiveType(node.type) && "expected a primitive type"); + if (node.typeConstant == Ast::TypeConstant::Infinity) + intrinsicHelpers[IntrinsicHelper::Infinity].emplace(node.type); + else if (node.typeConstant == Ast::TypeConstant::NaN) + intrinsicHelpers[IntrinsicHelper::NaN].emplace(node.type); + } + + void Visit(Ast::DeclareStructStatement& node) override + { + structs[node.structIndex.value()] = &node.description; + RecursiveVisitor::Visit(node); + } + + std::unordered_map structs; + std::unordered_map> intrinsicHelpers; + tsl::ordered_set features; + WgslWriter& m_writer; + }; + + struct WgslWriter::AutoBindingAttribute + { + const Ast::ExpressionValue& autoBinding; + + bool HasValue() const { return autoBinding.HasValue(); } + }; + + struct WgslWriter::AuthorAttribute + { + const std::string& author; + + bool HasValue() const { return !author.empty(); } + }; + + struct WgslWriter::BindingAttribute + { + const Ast::ExpressionValue& bindingIndex; + + bool HasValue() const { return bindingIndex.HasValue(); } + }; + + struct WgslWriter::BuiltinAttribute + { + const Ast::ExpressionValue& builtin; + + bool HasValue() const { return builtin.HasValue(); } + }; + + struct WgslWriter::CondAttribute + { + const Ast::ExpressionValue& cond; + + bool HasValue() const { return cond.HasValue(); } + }; + + struct WgslWriter::DepthWriteAttribute + { + const Ast::ExpressionValue& writeMode; + + bool HasValue() const { return writeMode.HasValue(); } + }; + + struct WgslWriter::DescriptionAttribute + { + const std::string& description; + + bool HasValue() const { return !description.empty(); } + }; + + struct WgslWriter::EarlyFragmentTestsAttribute + { + const Ast::ExpressionValue& earlyFragmentTests; + + bool HasValue() const { return earlyFragmentTests.HasValue(); } + }; + + struct WgslWriter::EntryAttribute + { + const Ast::ExpressionValue& stageType; + + bool HasValue() const { return stageType.HasValue(); } + }; + + struct WgslWriter::FeatureAttribute + { + Ast::ModuleFeature featureAttribute; + + bool HasValue() const { return true; } + }; + + struct WgslWriter::InterpAttribute + { + const Ast::ExpressionValue& interpQualifier; + + bool HasValue() const { return interpQualifier.HasValue(); } + }; + + struct WgslWriter::LicenseAttribute + { + const std::string& license; + + bool HasValue() const { return !license.empty(); } + }; + + struct WgslWriter::LocationAttribute + { + const Ast::ExpressionValue& locationIndex; + + bool HasValue() const { return locationIndex.HasValue(); } + }; + + struct WgslWriter::SetAttribute + { + const Ast::ExpressionValue& setIndex; + + bool HasValue() const { return setIndex.HasValue(); } + }; + + struct WgslWriter::TagAttribute + { + const std::string& tag; + + bool HasValue() const { return !tag.empty(); } + }; + + struct WgslWriter::UnrollAttribute + { + const Ast::ExpressionValue& unroll; + + bool HasValue() const { return unroll.HasValue(); } + }; + + struct WgslWriter::WorkgroupAttribute + { + const Ast::ExpressionValue& workgroup; + + bool HasValue() const { return workgroup.HasValue(); } + }; + + struct WgslWriter::State + { + State(const BackendParameters& backendParameters) : + backendParameters(backendParameters) + { + } + + struct Identifier + { + std::optional externalBlockIndex; + std::size_t moduleIndex; + std::string name; + bool isDereferenceable; + }; + + struct StructData : Identifier + { + const Ast::StructDescription* desc; + }; + + std::optional currentExternalBlockIndex; + std::size_t currentModuleIndex; + std::stringstream stream; + std::unordered_map aliases; + std::unordered_map constants; + std::unordered_map functions; + std::unordered_map modules; + std::unordered_map structs; + std::unordered_map variables; + std::unordered_map bindingRemap; + std::unordered_set reservedBindings; + std::vector externalBlockNames; + std::vector moduleNames; + const BackendParameters& backendParameters; + bool isInEntryPoint = false; + int streamEmptyLine = 1; + unsigned int indentLevel = 0; + bool isTerminatedScope = false; + bool hasf32RatioFunction = false; + bool hasf64RatioFunction = false; + bool hasDrawParametersBaseInstanceUniform = false; + bool hasDrawParametersBaseVertexUniform = false; + bool hasDrawParametersDrawIndexUniform = false; + }; + + WgslWriter::Output WgslWriter::Generate(Ast::Module& module, const BackendParameters& parameters) + { + State state(parameters); + + m_currentState = &state; + NAZARA_DEFER({ m_currentState = nullptr; }); + + if (parameters.backendPasses) + { + Ast::TransformerExecutor executor; + if (parameters.backendPasses.Test(BackendPass::Resolve)) + { + executor.AddPass([&](Ast::ResolveTransformer::Options& opt) + { + opt.moduleResolver = parameters.shaderModuleResolver; + }); + } + + if (parameters.backendPasses.Test(BackendPass::TargetRequired)) + RegisterPasses(executor); + + if (parameters.backendPasses.Test(BackendPass::Optimize)) + executor.AddPass(); + + if (parameters.backendPasses.Test(BackendPass::Validate)) + { + executor.AddPass([](Ast::ValidationTransformer::Options& opt) + { + opt.allowUntyped = false; + opt.checkIndices = true; + }); + } + + Ast::TransformerContext context; + context.optionValues = parameters.optionValues; + + executor.Transform(module, context); + } + + if (parameters.backendPasses.Test(BackendPass::RemoveDeadCode)) + { + Ast::DependencyCheckerVisitor::Config dependencyConfig; + dependencyConfig.usedShaderStages = ShaderStageType_All; + + Ast::EliminateUnusedPass(module, dependencyConfig); + } + + // First registration pass (required to register function names) + PreVisitor previsitor(*this); + { + m_currentState->currentModuleIndex = 0; + for (const auto& importedModule : module.importedModules) + { + m_currentState->currentModuleIndex++; + importedModule.module->rootNode->Visit(previsitor); + m_currentState->moduleNames.push_back(importedModule.identifier); + } + + m_currentState->currentModuleIndex = 0; + + std::size_t moduleIndex = 0; + for (const auto& importedModule : module.importedModules) + RegisterModule(moduleIndex++, importedModule.identifier); + + module.rootNode->Visit(previsitor); + } + + AppendHeader(*module.metadata); + + // Validate required features + auto validateFeature = [&](std::string_view featureName, std::string_view featurePrettyName) + { + if (!m_environment.featuresCallback || !m_environment.featuresCallback(featureName)) + throw std::runtime_error(fmt::format("WGSL does not support {} feature, {}you need to confirm its usage using feature callback", featurePrettyName, (featureName.find("Wgpu") != std::string::npos ? "some implementations do natively but " : ""))); + }; + + for (WgslFeature feature : previsitor.features) + { + switch (feature) + { + case WgslFeature::None: break; + + case WgslFeature::EmulateBaseInstance: + { + validateFeature("EmulateBaseInstance", "base instance attribute"); + m_currentState->hasDrawParametersBaseInstanceUniform = true; + break; + } + case WgslFeature::EmulateBaseVertex: + { + validateFeature("EmulateBaseVertex", "base vertex attribute"); + m_currentState->hasDrawParametersBaseVertexUniform = true; + break; + } + case WgslFeature::EmulateDrawIndex: + { + validateFeature("EmulateDrawIndex", "draw index attribute"); + m_currentState->hasDrawParametersDrawIndexUniform = true; + break; + } + + case WgslFeature::WgpuBufferBindingArray: validateFeature("WgpuBufferBindingArray", "buffer binding array"); break; + case WgslFeature::WgpuConservativeDepth: validateFeature("WgpuConservativeDepth", "conservative depth"); break; + case WgslFeature::WgpuEarlyFragmentTests: validateFeature("WgpuEarlyFragmentTests", "early fragment depth test"); break; + case WgslFeature::WgpuFloat64: validateFeature("WgpuFloat64", "float 64"); break; + case WgslFeature::WgpuPushConstants: validateFeature("WgpuPushConstants", "push constants"); break; + case WgslFeature::WgpuStorageBindingArray: validateFeature("WgpuStorageBindingArray", "storage resource binding array"); break; + case WgslFeature::WgpuTextureBindingArray: validateFeature("WgpuTextureBindingArray", "texture binding array"); break; + } + } + + if (m_currentState->hasDrawParametersBaseInstanceUniform || m_currentState->hasDrawParametersBaseVertexUniform || m_currentState->hasDrawParametersDrawIndexUniform) + { + AppendLine("struct ", s_wgslWriterBuiltinEmulationStructName, "Struct"); + EnterScope(); + { + if (m_currentState->hasDrawParametersBaseInstanceUniform) + AppendLine(s_wgslBuiltinMapping.at(Ast::BuiltinEntry::BaseInstance).identifier, ": u32,"); + if (m_currentState->hasDrawParametersBaseVertexUniform) + AppendLine(s_wgslBuiltinMapping.at(Ast::BuiltinEntry::BaseVertex).identifier, ": u32,"); + if (m_currentState->hasDrawParametersDrawIndexUniform) + AppendLine(s_wgslBuiltinMapping.at(Ast::BuiltinEntry::DrawIndex).identifier, ": u32,"); + } + LeaveScope(); + + const std::uint64_t emulationBindingGroup = 0; + std::uint32_t binding = 0; + for (; m_currentState->reservedBindings.count(emulationBindingGroup << 32 | binding); binding++); + m_currentState->reservedBindings.emplace(emulationBindingGroup << 32 | binding); + AppendLine("@group(", emulationBindingGroup, ") @binding(", binding, ") var ", s_wgslWriterBuiltinEmulationStructName, ": ", s_wgslWriterBuiltinEmulationStructName, "Struct;"); + + AppendLine(); + } + + // Register imported modules + m_currentState->currentModuleIndex = 0; + for (const auto& importedModule : module.importedModules) + { + AppendModuleAttributes(*importedModule.module->metadata); + AppendComment("Module " + importedModule.module->metadata->moduleName); + + m_currentState->currentModuleIndex++; + importedModule.module->rootNode->Visit(*this); + m_currentState->moduleNames.push_back(importedModule.identifier); + } + + for (const auto& [helper, exprTypeSet] : previsitor.intrinsicHelpers) + { + for (const auto& exprType : exprTypeSet) + AppendIntrinsicHelpers(helper, exprType); + } + + m_currentState->currentModuleIndex = 0; + module.rootNode->Visit(*this); + + Output output; + output.code = std::move(state.stream).str(); + output.bindingRemap = std::move(state.bindingRemap); + output.usesDrawParameterBaseInstanceUniform = m_currentState->hasDrawParametersBaseInstanceUniform; + output.usesDrawParameterBaseVertexUniform = m_currentState->hasDrawParametersBaseVertexUniform; + output.usesDrawParameterDrawIndexUniform = m_currentState->hasDrawParametersDrawIndexUniform; + + return output; + } + + void WgslWriter::SetEnv(Environment environment) + { + m_environment = std::move(environment); + } + + void WgslWriter::RegisterPasses(Ast::TransformerExecutor& executor) + { + // Wtf WGSL ? + static constexpr auto s_reservedKeywords = frozen::make_unordered_set({ + "NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment", "async", + "attribute", "auto", "await", "become", "cast", "catch", "class", "co_await", "co_return", "co_yield", + "coherent", "column_major", "common", "compile", "compile_fragment", "concept", "const_cast", "consteval", + "constexpr", "constinit", "crate", "debugger", "decltype", "delete", "demote", "demote_to_helper", + "do", "dynamic_cast", "enum", "explicit", "export", "extends", "extern", "external", "fallthrough", + "filter", "final", "finally", "friend", "from", "fxgroup", "get", "goto", "groupshared", "highp", "impl", + "implements", "import", "inline", "instanceof", "interface", "layout", "lowp", "macro", "macro_rules", + "match", "mediump", "meta", "mod", "module", "move", "mut", "mutable", "namespace", "new", "nil", + "noexcept", "noinline", "nointerpolation", "non_coherent", "noncoherent", "noperspective", "null", + "nullptr", "of", "operator", "package", "packoffset", "partition", "pass", "patch", "pixelfragment", + "precise", "precision", "premerge", "priv", "protected", "pub", "public", "readonly", "ref", "regardless", + "register", "reinterpret_cast", "require", "resource", "restrict", "self", "set", "shared", "sizeof", + "smooth", "snorm", "static", "static_assert", "static_cast", "std", "subroutine", "super", "target", + "template", "this", "thread_local", "throw", "trait", "try", "type", "typedef", "typeid", "typename", + "typeof", "union", "unless", "unorm", "unsafe", "unsized", "use", "using", "varying", "virtual", + "volatile", "wgsl", "where", "with", "writeonly", "yield", "alias", "break", "case", "const", "const_assert", + "continue", "continuing", "default", "diagnostic", "discard", "else", "enable", "false", "fn", "for", + "if", "let", "loop", "override", "requires", "return", "struct", "switch", "true", "var", "while" + }); + + // We need two identifiers passes, the first one to rename reserved/forbidden variable names and the second one to ensure all variables name are uniques (which isn't guaranteed by the transformation passes) + // We can't do this at once at the end because transformations passes will introduce variables prefixed by _nzsl which is forbidden in user code + Ast::IdentifierTransformer::Options firstIdentifierPassOptions; + firstIdentifierPassOptions.makeVariableNameUnique = false; + firstIdentifierPassOptions.identifierSanitizer = [](std::string& identifier, Ast::IdentifierCategory /*scope*/) + { + using namespace std::string_view_literals; + + bool nameChanged = false; + + // Identifier can't start with _nzsl + if (identifier.compare(0, 5, "_nzsl") == 0) + { + identifier.replace(0, 5, "_"sv); + nameChanged = true; + } + + // Identifier can't be only _ + if (identifier == "_") + { + identifier = "_2_2"; + nameChanged = true; + } + + return nameChanged; + }; + + Ast::IdentifierTransformer::Options secondIdentifierPassOptions; + secondIdentifierPassOptions.makeVariableNameUnique = true; + secondIdentifierPassOptions.identifierSanitizer = [](std::string& identifier, Ast::IdentifierCategory /*scope*/) + { + using namespace std::string_view_literals; + + bool nameChanged = false; + while (s_reservedKeywords.count(frozen::string(identifier)) != 0) + { + identifier += '_'; + nameChanged = true; + } + + // Replace __ by _X_ + std::size_t startPos = 0; + while ((startPos = identifier.find("__"sv, startPos)) != std::string::npos) + { + std::size_t endPos = identifier.find_first_not_of('_', startPos); + identifier.replace(startPos, endPos - startPos, fmt::format("{}{}_", (startPos == 0) ? "_" : "", endPos - startPos)); + + startPos = endPos; + nameChanged = true; + } + + return nameChanged; + }; + + executor.AddPass(); + executor.AddPass(); + executor.AddPass(firstIdentifierPassOptions); + executor.AddPass(); + executor.AddPass([](Ast::StructAssignmentTransformer::Options& opt) + { + opt.splitWrappedArrayAssignation = false; + opt.splitWrappedStructAssignation = true; + }); + executor.AddPass([](Ast::SwizzleTransformer::Options& opt) + { + opt.removeScalarSwizzling = true; + opt.removeSwizzleAssigment = true; + }); + executor.AddPass([](Ast::MatrixTransformer::Options& opt) + { + opt.removeMatrixBinaryAddSub = true; + opt.removeMatrixCast = true; + }); + executor.AddPass(); + executor.AddPass([](Ast::ConstantRemovalTransformer::Options& opt) + { + opt.removeConstArraySize = false; + opt.removeTypeConstant = false; + }); + executor.AddPass([](Ast::UniformStructToStd140Transformer::Options& opt) + { + opt.cloneStructIfUsedElsewhere = true; + }); + executor.AddPass(); + executor.AddPass(); + executor.AddPass(secondIdentifierPassOptions); + } + + void WgslWriter::Append(const Ast::AliasType& /*type*/) + { + throw std::runtime_error("unexpected AliasType"); + } + + void WgslWriter::Append(const Ast::ArrayType& type) + { + if (IsSamplerType(type.containedType->type)) + Append("binding_"); + Append("array<", type.containedType->type); + + if (type.length > 0) + Append(", ", type.length); + Append('>'); + } + + void WgslWriter::Append(const Ast::DynArrayType& type) + { + if (IsSamplerType(type.containedType->type)) + Append("binding_"); + Append("array<", type.containedType->type, ">"); + } + + void WgslWriter::Append(const Ast::ExpressionType& type) + { + std::visit([&](auto&& arg) + { + Append(arg); + }, type); + } + + void WgslWriter::Append(const Ast::ExpressionValue& type) + { + assert(type.HasValue()); + if (type.IsResultingValue()) + Append(type.GetResultingValue()); + else + type.GetExpression()->Visit(*this); + } + + void WgslWriter::Append(const Ast::FunctionType& /*functionType*/) + { + throw std::runtime_error("unexpected function type"); + } + + void WgslWriter::Append(const Ast::IntrinsicFunctionType& /*functionType*/) + { + throw std::runtime_error("unexpected intrinsic function type"); + } + + void WgslWriter::Append(const Ast::ImplicitArrayType& /*type*/) + { + throw std::runtime_error("unexpected ImplicitArrayType"); + } + + void WgslWriter::Append(const Ast::ImplicitMatrixType& /*type*/) + { + throw std::runtime_error("unexpected ImplicitMatrixType"); + } + + void WgslWriter::Append(const Ast::ImplicitVectorType& /*type*/) + { + throw std::runtime_error("unexpected ImplicitVectorType"); + } + + void WgslWriter::Append(const Ast::MatrixType& matrixType) + { + Append("mat"); + Append(matrixType.columnCount); + Append("x"); + Append(matrixType.rowCount); + Append("<", matrixType.type, ">"); + } + + void WgslWriter::Append(const Ast::MethodType& /*functionType*/) + { + throw std::runtime_error("unexpected method type"); + } + + void WgslWriter::Append(const Ast::ModuleType& /*moduleType*/) + { + throw std::runtime_error("unexpected module type"); + } + + void WgslWriter::Append(const Ast::NamedExternalBlockType& namedExternalBlockType) + { + AppendComment(m_currentState->externalBlockNames[namedExternalBlockType.namedExternalBlockIndex]); + } + + void WgslWriter::Append(Ast::PrimitiveType type) + { + switch (type) + { + case Ast::PrimitiveType::Boolean: return Append("bool"); + case Ast::PrimitiveType::Float32: return Append("f32"); + case Ast::PrimitiveType::Float64: return Append("f64"); + case Ast::PrimitiveType::Int32: return Append("i32"); + case Ast::PrimitiveType::UInt32: return Append("u32"); + case Ast::PrimitiveType::FloatLiteral: throw std::runtime_error("unexpected untyped float"); + case Ast::PrimitiveType::IntLiteral: throw std::runtime_error("unexpected untyped integer"); + case Ast::PrimitiveType::String: throw std::runtime_error("unexpected string type"); + } + } + + void WgslWriter::Append(const Ast::PushConstantType& pushConstantType) + { + Append(pushConstantType.containedType); + } + + void WgslWriter::Append(const Ast::SamplerType& samplerType) + { + std::string dimension; + std::string type; + switch (samplerType.dim) + { + case ImageType::E1D: + { + if (samplerType.depth) + throw std::runtime_error("depth texture sampler 1D are not supported by WGSL"); + dimension = "1d"; + break; + } + case ImageType::E1D_Array: + throw std::runtime_error("texture 1D array are not supported by WGSL"); + case ImageType::E2D: dimension = "2d"; break; + case ImageType::E2D_Array: dimension = "2d_array"; break; + case ImageType::E3D: + { + if (samplerType.depth) + throw std::runtime_error("depth texture sampler 3D are not supported by WGSL"); + dimension = "3d"; + break; + } + case ImageType::Cubemap: dimension = "cube"; break; + } + switch (samplerType.sampledType) + { + case Ast::PrimitiveType::Boolean: + throw std::runtime_error("unexpected bool type for sampled texture"); + case Ast::PrimitiveType::Float64: + throw std::runtime_error("unexpected f64 type for sampled texture"); + + case Ast::PrimitiveType::Float32: type = ""; break; + case Ast::PrimitiveType::Int32: type = ""; break; + case Ast::PrimitiveType::UInt32: type = ""; break; + + case Ast::PrimitiveType::String: + throw std::runtime_error("unexpected string type for sampled texture"); + + case Ast::PrimitiveType::FloatLiteral: + case Ast::PrimitiveType::IntLiteral: + throw std::runtime_error("unexpected litteral type for sampled texture"); + } + Append("texture_"); + if (samplerType.depth) + Append("depth_", dimension); + else + Append(dimension, type); + } + + void WgslWriter::Append(const Ast::StorageType& storageType) + { + Append(storageType.containedType); + } + + void WgslWriter::Append(const Ast::StructType& structType) + { + AppendIdentifier(m_currentState->structs, structType.structIndex, true); + } + + void WgslWriter::Append(const Ast::TextureType& textureType) + { + Append("texture_storage_"); + switch (textureType.dim) + { + case ImageType::E1D: Append("1d"); break; + case ImageType::E2D: Append("2d"); break; + case ImageType::E2D_Array: Append("2d_array"); break; + case ImageType::E3D: Append("3d"); break; + + default: + throw std::runtime_error("unexpected storage texture type"); + } + Append("<"); + switch (textureType.format) + { + case ImageFormat::RGBA8: Append("rgba8unorm"); break; + case ImageFormat::RGBA8i: Append("rgba8sint"); break; + case ImageFormat::RGBA8Snorm: Append("rgba8snorm"); break; + case ImageFormat::RGBA8ui: Append("rgba8uint"); break; + + case ImageFormat::RGBA16f: Append("rgba16float"); break; + case ImageFormat::RGBA16i: Append("rgba16sint"); break; + case ImageFormat::RGBA16ui: Append("rgba16uint"); break; + + case ImageFormat::R32f: Append("r32float"); break; + case ImageFormat::R32i: Append("r32sint"); break; + case ImageFormat::R32ui: Append("r32uint"); break; + + case ImageFormat::RG32f: Append("rg32float"); break; + case ImageFormat::RG32i: Append("rg32sint"); break; + case ImageFormat::RG32ui: Append("rg32uint"); break; + + case ImageFormat::RGBA32f: Append("rgba32float"); break; + case ImageFormat::RGBA32i: Append("rgba32sint"); break; + case ImageFormat::RGBA32ui: Append("rgba32uint"); break; + + default: + throw std::runtime_error("unexpected format type for texture"); + } + Append(", "); + switch (textureType.accessPolicy) + { + case AccessPolicy::ReadOnly: Append("read"); break; + case AccessPolicy::ReadWrite: Append("read_write"); break; + case AccessPolicy::WriteOnly: Append("write"); break; + } + Append(">"); + } + + void WgslWriter::Append(const Ast::Type& /*type*/) + { + throw std::runtime_error("unexpected type"); + } + + void WgslWriter::Visit(Ast::TypeConstantExpression& node) + { + assert(IsPrimitiveType(node.type)); + Ast::PrimitiveType primitiveType = std::get(node.type); + + auto AppendConstant = [&](auto&& type) + { + using T = std::decay_t; + + if (node.typeConstant == Ast::TypeConstant::Max) + { + if constexpr (std::is_same_v) + return Append("3.402823466e+38"); + else if constexpr (std::is_same_v) + return Append("1.7976931348623158e+308lf"); + else + return AppendValue(Nz::MaxValue()); + } + + if (node.typeConstant == Ast::TypeConstant::Min) + { + if constexpr (std::is_same_v) + return Append("-3.402823466e+38"); + else if constexpr (std::is_same_v) + return Append("-1.7976931348623158e+308lf"); + else + return AppendValue(std::numeric_limits::lowest()); //< Nz::MinValue is implemented by std::numeric_limits::min() which doesn't give the value we want + } + + if constexpr (std::is_floating_point_v) + { + if (node.typeConstant == Ast::TypeConstant::Epsilon) + { + if constexpr (std::is_same_v) + return Append("1.192092896e-07"); + else if constexpr (std::is_same_v) + return Append("2.2204460492503131e-016lf"); + else + static_assert(Nz::AlwaysFalse(), "unhandled type"); + } + + if (node.typeConstant == Ast::TypeConstant::Infinity) + { + if constexpr (std::is_same_v) + return Append("_nzslInfinityf32()"); + else if constexpr (std::is_same_v) + return Append("_nzslInfinityf64()"); + else + static_assert(Nz::AlwaysFalse(), "unhandled type"); + } + + if (node.typeConstant == Ast::TypeConstant::MinPositive) + { + if constexpr (std::is_same_v) + return Append("1.175494351e-38"); + else if constexpr (std::is_same_v) + return Append("2.2250738585072014e-308lf"); + else + static_assert(Nz::AlwaysFalse(), "unhandled type"); + } + + if (node.typeConstant == Ast::TypeConstant::NaN) + { + if constexpr (std::is_same_v) + return Append("_nzslNaNf32()"); + else if constexpr (std::is_same_v) + return Append("_nzslNaNf64()"); + else + static_assert(Nz::AlwaysFalse(), "unhandled type"); + } + } + + throw std::runtime_error("unexpected type constant with type"); + }; + + switch (primitiveType) + { + case Ast::PrimitiveType::Float32: AppendConstant(float{}); break; + case Ast::PrimitiveType::Float64: AppendConstant(double{}); break; + case Ast::PrimitiveType::Int32: AppendConstant(std::int32_t{}); break; + case Ast::PrimitiveType::UInt32: AppendConstant(std::uint32_t{}); break; + + case Ast::PrimitiveType::Boolean: + case Ast::PrimitiveType::FloatLiteral: + case Ast::PrimitiveType::IntLiteral: + case Ast::PrimitiveType::String: + throw std::runtime_error("unexpected primitive type"); + } + } + + void WgslWriter::Append(const Ast::UniformType& uniformType) + { + Append(uniformType.containedType); + } + + void WgslWriter::Append(const Ast::VectorType& vecType) + { + Append("vec", vecType.componentCount, "<", vecType.type, ">"); + } + + void WgslWriter::Append(Ast::NoType) + { + return Append("()"); + } + + template + void WgslWriter::Append(const T& param) + { + assert(m_currentState && "This function should only be called while processing an AST"); + + if (m_currentState->streamEmptyLine > 0) + { + for (std::size_t i = 0; i < m_currentState->indentLevel; ++i) + m_currentState->stream << '\t'; + + m_currentState->streamEmptyLine = 0; + } + + m_currentState->stream << param; + } + + template + void WgslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params) + { + Append(firstParam); + Append(secondParam, std::forward(params)...); + } + + template + void WgslWriter::AppendAttributes(bool appendLine, Args&&... params) + { + bool hasAnyAttribute = (params.HasValue() || ...); + if (!hasAnyAttribute) + return; + + bool first = true; + + AppendAttributesInternal(first, std::forward(params)...); + + if (appendLine) + AppendLine(); + else + Append(" "); + } + + template + void WgslWriter::AppendAttributesInternal(bool& first, const T& param) + { + if (!param.HasValue()) + return; + + AppendAttribute(first, param); + first = false; + } + + template + void WgslWriter::AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params) + { + AppendAttributesInternal(first, firstParam); + AppendAttributesInternal(first, secondParam, std::forward(params)...); + } + + void WgslWriter::AppendAttribute(bool /*first*/, AutoBindingAttribute /*attribute*/) + { + // Nothing to do + } + + void WgslWriter::AppendAttribute(bool /*first*/, AuthorAttribute attribute) + { + if (!attribute.HasValue()) + return; + AppendComment("Author " + EscapeString(attribute.author)); + } + + void WgslWriter::AppendAttribute(bool first, BindingAttribute attribute) + { + if (!attribute.HasValue()) + return; + if (!first) + Append(" "); + Append("@"); + + Append("binding("); + + if (attribute.bindingIndex.IsResultingValue()) + Append(attribute.bindingIndex.GetResultingValue()); + else + attribute.bindingIndex.GetExpression()->Visit(*this); + + Append(")"); + } + + void WgslWriter::AppendAttribute(bool first, BuiltinAttribute attribute) + { + if (!attribute.HasValue()) + return; + auto it = s_wgslBuiltinMapping.find(attribute.builtin.GetResultingValue()); + assert(it != s_wgslBuiltinMapping.end()); + if (it->second.identifier.empty()) + throw std::runtime_error("unsupported builtin attribute!"); + else if (std::find(s_wgslBuiltinsToEmulate.begin(), s_wgslBuiltinsToEmulate.end(), it->first) != s_wgslBuiltinsToEmulate.end()) + return; + if (!first) + Append(" "); + Append("@"); + Append("builtin(", it->second.identifier, ")"); + } + + void WgslWriter::AppendAttribute(bool /*first*/, CondAttribute /*attribute*/) + { + // Nothing to do + } + + void WgslWriter::AppendAttribute(bool first, DepthWriteAttribute attribute) + { + if (!attribute.HasValue() || attribute.writeMode.GetResultingValue() == Ast::DepthWriteMode::Replace) + return; + if (!first) + Append(" "); + switch (attribute.writeMode.GetResultingValue()) + { + case Ast::DepthWriteMode::Greater: Append("@early_depth_test(greater_equal)"); break; + case Ast::DepthWriteMode::Less: Append("@early_depth_test(less_equal)"); break; + case Ast::DepthWriteMode::Replace: break; // Should never be triggered + case Ast::DepthWriteMode::Unchanged: Append("@early_depth_test(unchanged)"); break; + } + } + + void WgslWriter::AppendAttribute(bool /*first*/, DescriptionAttribute attribute) + { + if (!attribute.HasValue()) + return; + AppendComment("Description: " + EscapeString(attribute.description)); + } + + void WgslWriter::AppendAttribute(bool first, EarlyFragmentTestsAttribute attribute) + { + if (!attribute.HasValue() || !attribute.earlyFragmentTests.GetResultingValue()) + return; + if (!first) + Append(" "); + Append("@early_depth_test(force)"); + } + + void WgslWriter::AppendAttribute(bool first, EntryAttribute attribute) + { + if (!attribute.HasValue()) + return; + if (!first) + Append(" "); + Append("@"); + + if (attribute.stageType.IsResultingValue()) + { + switch (attribute.stageType.GetResultingValue()) + { + case ShaderStageType::Compute: Append("compute"); break; + case ShaderStageType::Fragment: Append("fragment"); break; + case ShaderStageType::Vertex: Append("vertex"); break; + } + } + else + attribute.stageType.GetExpression()->Visit(*this); + } + + void WgslWriter::AppendAttribute(bool /*first*/, FeatureAttribute attribute) + { + if (!attribute.HasValue()) + return; + + switch (attribute.featureAttribute) + { + case Ast::ModuleFeature::Float64: + { + if (!m_environment.featuresCallback || !m_environment.featuresCallback("WgpuFloat64")) + throw std::runtime_error("WGSL does not support float64 feature, wgpu does natively but you need to confirm its usage using feature callback"); + break; + } + + case Ast::ModuleFeature::PrimitiveExternals: + throw std::runtime_error("primitive externals have no way to be translated in WGSL"); + break; + + case Ast::ModuleFeature::Texture1D: + // Supported by WGSL + break; + } + } + + void WgslWriter::AppendAttribute(bool first, InterpAttribute attribute) + { + if (!attribute.HasValue()) + return; + if (!first) + Append(" "); + Append("@interpolate("); + + const auto interpQualifierNames = frozen::make_unordered_map({ + { Ast::InterpolationQualifier::Flat, "flat" }, + { Ast::InterpolationQualifier::NoPerspective, "perspective" }, + { Ast::InterpolationQualifier::Smooth, "linear" }, + }); + + if (attribute.interpQualifier.IsResultingValue()) + Append(interpQualifierNames.at(attribute.interpQualifier.GetResultingValue())); + else + attribute.interpQualifier.GetExpression()->Visit(*this); + + Append(")"); + } + + void WgslWriter::AppendAttribute(bool /*first*/, LicenseAttribute attribute) + { + if (!attribute.HasValue()) + return; + AppendComment("License: " + EscapeString(attribute.license)); + } + + void WgslWriter::AppendAttribute(bool first, LocationAttribute attribute) + { + if (!attribute.HasValue()) + return; + if (!first) + Append(" "); + Append("@"); + + Append("location("); + + if (attribute.locationIndex.IsResultingValue()) + Append(attribute.locationIndex.GetResultingValue()); + else + attribute.locationIndex.GetExpression()->Visit(*this); + + Append(")"); + } + + void WgslWriter::AppendAttribute(bool first, SetAttribute attribute) + { + if (!attribute.HasValue()) + return; + if (!first) + Append(" "); + Append("@"); + + Append("group("); + + if (attribute.setIndex.IsResultingValue()) + Append(attribute.setIndex.GetResultingValue()); + else + attribute.setIndex.GetExpression()->Visit(*this); + + Append(")"); + } + + void WgslWriter::AppendAttribute(bool /*first*/, TagAttribute attribute) + { + if (!attribute.HasValue()) + return; + AppendComment("Tag: " + attribute.tag); + } + + void WgslWriter::AppendAttribute(bool /*first*/, UnrollAttribute /*attribute*/) + { + throw std::runtime_error("unexpected unroll attribute, is the shader sanitized?"); + } + + void WgslWriter::AppendAttribute(bool first, WorkgroupAttribute attribute) + { + if (!attribute.HasValue()) + return; + if (!first) + Append(" "); + Append("@"); + + Append("workgroup_size("); + + if (attribute.workgroup.IsResultingValue()) + { + const Vector3u32& workgroupSize = attribute.workgroup.GetResultingValue(); + Append(workgroupSize.x(), ", ", workgroupSize.y(), ", ", workgroupSize.z()); + } + else + { + const Ast::ExpressionPtr& workgroupExpr = attribute.workgroup.GetExpression(); + if (workgroupExpr->GetType() != Ast::NodeType::CastExpression) + throw std::runtime_error("expected workgroup expression to be a cast expression"); + + const Ast::CastExpression& workgroupCast = static_cast(*workgroupExpr); + if (!workgroupCast.targetType.IsResultingValue() || workgroupCast.targetType.GetResultingValue() != Ast::ExpressionType{ Ast::VectorType{ 3, Ast::PrimitiveType::UInt32 }}) + throw std::runtime_error("expected workgroup expression to be a cast to vec3[u32]"); + + if (workgroupCast.expressions.size() != 3) + throw std::runtime_error("expected workgroup expression to be a cast of 3 expressions"); + + workgroupCast.expressions[0]->Visit(*this); + Append(", "); + workgroupCast.expressions[1]->Visit(*this); + Append(", "); + workgroupCast.expressions[2]->Visit(*this); + } + + Append(")"); + } + + void WgslWriter::AppendComment(std::string_view section) + { + std::size_t lineFeed = section.find('\n'); + if (lineFeed != section.npos) + { + std::size_t previousCut = 0; + + AppendLine("/*"); + do + { + AppendLine(section.substr(previousCut, lineFeed - previousCut)); + previousCut = lineFeed + 1; + } + while ((lineFeed = section.find('\n', previousCut)) != section.npos); + AppendLine(section.substr(previousCut)); + AppendLine("*/"); + } + else + AppendLine("// ", section); + } + + void WgslWriter::AppendCommentSection(std::string_view section) + { + assert(m_currentState && "This function should only be called while processing an AST"); + + std::string stars((section.size() < 33) ? (36 - section.size()) / 2 : 3, '*'); + Append("/*", stars, ' ', section, ' ', stars, "*/"); + AppendLine(); + } + + void WgslWriter::AppendIntrinsicHelpers(IntrinsicHelper helper, const Ast::ExpressionType& type) + { + using namespace std::string_view_literals; + + Ast::PrimitiveType primitiveType; + if (IsMatrixType(type)) + primitiveType = std::get(type).type; + else if (IsPrimitiveType(type)) + primitiveType = std::get(type); + else + throw std::runtime_error("expected a matrix type or a primitive type"); + + std::string_view stringPrimitiveType; + switch (primitiveType) + { + case Ast::PrimitiveType::Float32: stringPrimitiveType = "f32"sv; break; + case Ast::PrimitiveType::Float64: stringPrimitiveType = "f64"sv; break; + + default: + throw std::runtime_error(fmt::format("expected primitive type f32 or f64, got {}", ToString(primitiveType))); + } + + auto setupRatioFunction = [this, primitiveType, stringPrimitiveType]() + { + if (primitiveType == Ast::PrimitiveType::Float32) + { + if (m_currentState->hasf32RatioFunction) + return; + m_currentState->hasf32RatioFunction = true; + } + else if (primitiveType == Ast::PrimitiveType::Float64) + { + if (m_currentState->hasf64RatioFunction) + return; + m_currentState->hasf64RatioFunction = true; + } + + Append(fmt::format(R"(fn _nzslRatio{0}(n: {0}, d: {0}) -> {0} +{{ + return n / d; +}} + +)", stringPrimitiveType)); + }; + + switch (helper) + { + case IntrinsicHelper::NaN: + { + setupRatioFunction(); + Append(fmt::format(R"(fn _nzslNaN{0}() -> {0} +{{ + return _nzslRatio{0}(0.0, 0.0); +}} + +)", stringPrimitiveType)); + break; + } + case IntrinsicHelper::Infinity: + { + setupRatioFunction(); + Append(fmt::format(R"(fn _nzslInfinity{0}() -> {0} +{{ + return _nzslRatio{0}(1.0, 0.0); +}} + +)", stringPrimitiveType)); + break; + } + + case IntrinsicHelper::MatrixInverse: + { + const Ast::MatrixType& matrixType = std::get(type); + assert(matrixType.rowCount == matrixType.columnCount); // Should have been catched before WgslWriter + if (matrixType.columnCount == 2) // mat2x2 + { + Append(fmt::format(R"(fn _nzslMatrixInverse2x2{0}(m: mat2x2<{0}>) -> mat2x2<{0}> +{{ + var adj: mat2x2<{0}>; + adj[0][0] = m[1][1]; + adj[0][1] = -m[0][1]; + adj[1][0] = -m[1][0]; + adj[1][1] = m[0][0]; + + let det: {0} = m[0][0] * m[1][1] - m[1][0] * m[0][1]; + return adj * (1 / det); +}} + +)", stringPrimitiveType)); + } + else if (matrixType.columnCount == 3) // mat3x3 + { + Append(fmt::format(R"(fn _nzslMatrixInverse3x3{0}(m: mat3x3<{0}>) -> mat3x3<{0}> +{{ + var adj: mat3x3<{0}>; + + adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]); + adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]); + adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]); + adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]); + adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]); + adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]); + adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]); + adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]); + + let det: {0} = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) + - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) + + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])); + + return adj * (1 / det); +}} + +)", stringPrimitiveType)); + } + else if (matrixType.columnCount == 4) // mat4x4 + { + Append(fmt::format(R"(fn _nzslMatrixInverse4x4{0}(m: mat4x4<{0}>) -> mat4x4<{0}> +{{ + let sub_factor00: {0} = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + let sub_factor01: {0} = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + let sub_factor02: {0} = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + let sub_factor03: {0} = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + let sub_factor04: {0} = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + let sub_factor05: {0} = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + let sub_factor06: {0} = m[1][2] * m[3][3] - m[3][2] * m[1][3]; + let sub_factor07: {0} = m[1][1] * m[3][3] - m[3][1] * m[1][3]; + let sub_factor08: {0} = m[1][1] * m[3][2] - m[3][1] * m[1][2]; + let sub_factor09: {0} = m[1][0] * m[3][3] - m[3][0] * m[1][3]; + let sub_factor10: {0} = m[1][0] * m[3][2] - m[3][0] * m[1][2]; + let sub_factor11: {0} = m[1][1] * m[3][3] - m[3][1] * m[1][3]; + let sub_factor12: {0} = m[1][0] * m[3][1] - m[3][0] * m[1][1]; + let sub_factor13: {0} = m[1][2] * m[2][3] - m[2][2] * m[1][3]; + let sub_factor14: {0} = m[1][1] * m[2][3] - m[2][1] * m[1][3]; + let sub_factor15: {0} = m[1][1] * m[2][2] - m[2][1] * m[1][2]; + let sub_factor16: {0} = m[1][0] * m[2][3] - m[2][0] * m[1][3]; + let sub_factor17: {0} = m[1][0] * m[2][2] - m[2][0] * m[1][2]; + let sub_factor18: {0} = m[1][0] * m[2][1] - m[2][0] * m[1][1]; + + var adj: mat4x4<{0}>; + adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02); + adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04); + adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05); + adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05); + adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02); + adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04); + adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05); + adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05); + adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08); + adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10); + adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12); + adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12); + adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15); + adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17); + adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18); + adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18); + + let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]); + + return adj * (1 / det); +}} + +)", stringPrimitiveType)); + } + break; + } + } + } + + void WgslWriter::AppendLine(std::string_view txt) + { + assert(m_currentState && "This function should only be called while processing an AST"); + + if (txt.empty() && m_currentState->streamEmptyLine > 1) + return; + + m_currentState->stream << txt << '\n'; + m_currentState->streamEmptyLine++; + } + + template + void WgslWriter::AppendIdentifier(const T& map, std::size_t id, bool append_module_prefix) + { + const auto& identifier = Nz::Retrieve(map, id); + if (append_module_prefix && identifier.moduleIndex != 0) + Append(m_currentState->moduleNames[identifier.moduleIndex - 1], '_'); + + Append(identifier.name); + } + + template + void WgslWriter::AppendLine(Args&&... params) + { + (Append(std::forward(params)), ...); + AppendLine(); + } + + template + void WgslWriter::AppendValue(const T& value) + { + if constexpr (std::is_same_v::reference>) + { + // fallback for std::vector + bool v = value; + return AppendValue(v); + } + else if constexpr (IsVector_v) + { + std::string str = Ast::ConstantToString(value); + std::replace(str.begin(), str.end(), '[', '<'); + std::replace(str.begin(), str.end(), ']', '>'); + Append(str); + } + else if constexpr (std::is_same_v) + { + Append(Ast::ToString(value)); + Append("u"); + } + else + Append(Ast::ConstantToString(value)); + } + + void WgslWriter::AppendModuleAttributes(const Ast::Module::Metadata& metadata) + { + for (Ast::ModuleFeature feature : metadata.enabledFeatures) + AppendAttributes(true, FeatureAttribute{ feature }); // Not a real append, it just checks the feature support + AppendAttributes(true, AuthorAttribute{ metadata.author }, DescriptionAttribute{ metadata.description }, LicenseAttribute{ metadata.license }); + } + + void WgslWriter::AppendStatementList(std::vector& statements) + { + bool first = true; + for (const Ast::StatementPtr& statement : statements) + { + if (statement->GetType() == Ast::NodeType::NoOpStatement) + continue; + + if (!first) + AppendLine(); + + statement->Visit(*this); + + first = false; + } + } + + void WgslWriter::EnterScope() + { + assert(m_currentState && "This function should only be called while processing an AST"); + + AppendLine("{"); + m_currentState->indentLevel++; + } + + void WgslWriter::LeaveScope(bool skipLine) + { + assert(m_currentState && "This function should only be called while processing an AST"); + + m_currentState->indentLevel--; + AppendLine(); + + if (skipLine) + AppendLine("}"); + else + Append("}"); + } + + void WgslWriter::RegisterAlias(std::size_t aliasIndex, std::string aliasName) + { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(aliasName); + + assert(m_currentState->aliases.find(aliasIndex) == m_currentState->aliases.end()); + m_currentState->aliases.emplace(aliasIndex, std::move(identifier)); + } + + void WgslWriter::RegisterConstant(std::size_t constantIndex, std::string constantName) + { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(constantName); + + assert(m_currentState->constants.find(constantIndex) == m_currentState->constants.end()); + m_currentState->constants.emplace(constantIndex, std::move(identifier)); + } + + void WgslWriter::RegisterFunction(std::size_t funcIndex, std::string functionName) + { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(functionName); + + assert(m_currentState->functions.find(funcIndex) == m_currentState->functions.end()); + m_currentState->functions.emplace(funcIndex, std::move(identifier)); + } + + void WgslWriter::RegisterModule(std::size_t moduleIndex, std::string moduleName) + { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(moduleName); + + assert(m_currentState->modules.find(moduleIndex) == m_currentState->modules.end()); + m_currentState->modules.emplace(moduleIndex, std::move(identifier)); + } + + void WgslWriter::RegisterStruct(std::size_t structIndex, const Ast::StructDescription& structDescription) + { + State::StructData structData; + structData.moduleIndex = m_currentState->currentModuleIndex; + structData.name = structDescription.name; + structData.desc = &structDescription; + + assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); + m_currentState->structs.emplace(structIndex, std::move(structData)); + } + + void WgslWriter::RegisterVariable(std::size_t varIndex, std::string varName, bool isInout) + { + State::Identifier identifier; + identifier.externalBlockIndex = m_currentState->currentExternalBlockIndex; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(varName); + identifier.isDereferenceable = isInout; + + assert(m_currentState->variables.find(varIndex) == m_currentState->variables.end()); + m_currentState->variables.emplace(varIndex, std::move(identifier)); + } + + void WgslWriter::ScopeVisit(Ast::Statement& node) + { + if (node.GetType() != Ast::NodeType::ScopedStatement) + { + EnterScope(); + node.Visit(*this); + LeaveScope(true); + } + else + node.Visit(*this); + } + + void WgslWriter::Visit(Ast::ExpressionPtr& expr, bool encloseIfRequired) + { + bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) == Ast::ExpressionCategory::Temporary); + + if (enclose) + Append("("); + + expr->Visit(*this); + + if (enclose) + Append(")"); + } + + void WgslWriter::Visit(Ast::AccessFieldExpression& node) + { + // In this implementation we do not visit struct identifier first + // as if we access an emulated builtin we do not want struct's name + // in front. + // Instead we search for member to access, if it is an emulated builtin + // we append it's uniform name. If not we store the access statement + // in a string, visit struct's name and then append the statement. + + const Ast::ExpressionType* exprType = GetExpressionType(*node.expr); + NazaraUnused(exprType); + assert(exprType); + assert(IsStructAddressible(*exprType)); + + std::size_t structIndex = Ast::ResolveStructIndex(*exprType); + assert(structIndex != std::numeric_limits::max()); + + const auto& structData = Nz::Retrieve(m_currentState->structs, structIndex); + + std::string_view memberName; + + std::uint32_t remainingIndices = node.fieldIndex; + for (const auto& member : structData.desc->members) + { + if (member.cond.HasValue() && !member.cond.GetResultingValue()) + continue; + + if (remainingIndices == 0) + { + if (member.builtin.HasValue()) + { + if (std::find(s_wgslBuiltinsToEmulate.begin(), s_wgslBuiltinsToEmulate.end(), member.builtin.GetResultingValue()) != s_wgslBuiltinsToEmulate.end()) + { + auto it = s_wgslBuiltinMapping.find(member.builtin.GetResultingValue()); + assert(it != s_wgslBuiltinMapping.end()); + Append(s_wgslWriterBuiltinEmulationStructName, '.', it->second.identifier); + return; + } + } + memberName = member.name; + break; + } + + remainingIndices--; + } + + Visit(node.expr, true); + Append('.', memberName); + } + + void WgslWriter::Visit(Ast::AccessIdentifierExpression& node) + { + Visit(node.expr, true); + + for (const auto& identifierEntry : node.identifiers) + Append(".", identifierEntry.identifier); + } + + void WgslWriter::Visit(Ast::AccessIndexExpression& node) + { + Visit(node.expr, true); + + for (Ast::ExpressionPtr& expr : node.indices) + { + Append('['); + expr->Visit(*this); + Append(']'); + } + } + + void WgslWriter::Visit(Ast::IdentifierValueExpression& node) + { + switch (node.identifierType) + { + case Ast::IdentifierType::Alias: throw std::runtime_error("unexpected Alias identifier, shader is not properly resolved"); + case Ast::IdentifierType::Intrinsic: throw std::runtime_error("unexpected Intrinsic identifier, shader is not properly resolved"); + case Ast::IdentifierType::Type: throw std::runtime_error("unexpected Type identifier, shader is not properly resolved"); + case Ast::IdentifierType::Unresolved: throw std::runtime_error("unexpected Unresolved identifier, shader is not properly resolved"); + + case Ast::IdentifierType::ExternalBlock: + { + Append(m_currentState->externalBlockNames[node.identifierIndex]); + break; + } + + case Ast::IdentifierType::Module: + { + AppendIdentifier(m_currentState->modules, node.identifierIndex); + break; + } + + case Ast::IdentifierType::Struct: + { + AppendIdentifier(m_currentState->structs, node.identifierIndex, true); + break; + } + + case Ast::IdentifierType::Constant: + { + AppendIdentifier(m_currentState->constants, node.identifierIndex); + break; + } + + case Ast::IdentifierType::Function: + { + AppendIdentifier(m_currentState->functions, node.identifierIndex, true); + break; + } + + case Ast::IdentifierType::Variable: + { + if (m_currentState->variables[node.identifierIndex].isDereferenceable) + Append('*'); + AppendIdentifier(m_currentState->variables, node.identifierIndex); + break; + } + } + } + + void WgslWriter::Visit(Ast::AssignExpression& node) + { + node.left->Visit(*this); + + switch (node.op) + { + case Ast::AssignType::Simple: Append(" = "); break; + case Ast::AssignType::CompoundAdd: Append(" += "); break; + case Ast::AssignType::CompoundDivide: Append(" /= "); break; + case Ast::AssignType::CompoundModulo: Append(" %= "); break; + case Ast::AssignType::CompoundMultiply: Append(" *= "); break; + case Ast::AssignType::CompoundLogicalAnd: Append(" &&= "); break; + case Ast::AssignType::CompoundLogicalOr: Append(" ||= "); break; + case Ast::AssignType::CompoundSubtract: Append(" -= "); break; + } + + node.right->Visit(*this); + } + + void WgslWriter::Visit(Ast::BinaryExpression& node) + { + bool needsClosingCast = false; + + Visit(node.left, true); + + switch (node.op) + { + case Ast::BinaryType::Add: Append(" + "); break; + case Ast::BinaryType::Subtract: Append(" - "); break; + case Ast::BinaryType::Modulo: Append(" % "); break; + case Ast::BinaryType::Multiply: Append(" * "); break; + case Ast::BinaryType::Divide: Append(" / "); break; + + case Ast::BinaryType::CompEq: Append(" == "); break; + case Ast::BinaryType::CompGe: Append(" >= "); break; + case Ast::BinaryType::CompGt: Append(" > "); break; + case Ast::BinaryType::CompLe: Append(" <= "); break; + case Ast::BinaryType::CompLt: Append(" < "); break; + case Ast::BinaryType::CompNe: Append(" != "); break; + + case Ast::BinaryType::LogicalAnd: Append(" && "); break; + case Ast::BinaryType::LogicalOr: Append(" || "); break; + + case Ast::BinaryType::BitwiseAnd: Append(" & "); break; + case Ast::BinaryType::BitwiseOr: Append(" | "); break; + case Ast::BinaryType::BitwiseXor: Append(" ^ "); break; + case Ast::BinaryType::ShiftLeft: Append(" << "); break; + case Ast::BinaryType::ShiftRight: Append(" >> "); break; + } + + if (node.op == Ast::BinaryType::ShiftLeft || node.op == Ast::BinaryType::ShiftRight) + { + const Ast::ExpressionType& rightType = Ast::ResolveAlias(Ast::EnsureExpressionType(*node.right)); + + if (Ast::IsVectorType(rightType)) + { + Ast::VectorType vectorType = std::get(rightType); + if (vectorType.type == Ast::PrimitiveType::Int32) + { + Append("vec"); + Append(std::to_string(vectorType.componentCount)); + Append("("); + needsClosingCast = true; + } + } + else if (Ast::IsPrimitiveType(rightType) && std::get(rightType) == Ast::PrimitiveType::Int32) + { + Append("u32("); + needsClosingCast = true; + } + } + Visit(node.right, true); + if (needsClosingCast) + Append(")"); + } + + void WgslWriter::Visit(Ast::CallFunctionExpression& node) + { + node.targetFunction->Visit(*this); + + Append("("); + for (std::size_t i = 0; i < node.parameters.size(); ++i) + { + if (i != 0) + Append(", "); + if (node.parameters[i].semantic != Ast::FunctionParameterSemantic::In) + Append('&'); + node.parameters[i].expr->Visit(*this); + + const auto& varType = *GetExpressionType(*node.parameters[i].expr); + Ast::ExpressionType rawOrContainedType; + if (IsArrayType(varType)) + rawOrContainedType = std::get(varType).containedType->type; + else if (IsDynArrayType(varType)) + rawOrContainedType = std::get(varType).containedType->type; + else + rawOrContainedType = varType; + if (IsSamplerType(rawOrContainedType)) + { + Append(", "); + node.parameters[i].expr->Visit(*this); + Append("Sampler"); + } + } + Append(")"); + } + + void WgslWriter::Visit(Ast::CastExpression& node) + { + Append(node.targetType); + Append("("); + + bool first = true; + for (const auto& exprPtr : node.expressions) + { + if (!first) + Append(", "); + + first = false; + + exprPtr->Visit(*this); + } + + Append(")"); + } + + void WgslWriter::Visit(Ast::ConditionalExpression& /*node*/) + { + throw std::runtime_error("unexpected conditional expression, is shader sanitized?"); + } + + void WgslWriter::Visit(Ast::ConstantArrayValueExpression& node) + { + Append(*node.cachedExpressionType); + m_currentState->indentLevel++; + AppendLine("("); + std::visit([&](auto&& vec) + { + using T = std::decay_t; + + if constexpr (std::is_same_v) + throw std::runtime_error("unexpected array of NoValue"); + else + { + for (std::size_t i = 0; i < vec.size(); ++i) + { + if (i != 0) + AppendLine(","); + + AppendValue(vec[i]); + } + } + }, node.values); + m_currentState->indentLevel--; + AppendLine(); + Append(")"); + } + + void WgslWriter::Visit(Ast::ConstantValueExpression& node) + { + std::visit([&](auto&& arg) + { + AppendValue(arg); + }, node.value); + } + + void WgslWriter::Visit(Ast::IdentifierExpression& node) + { + Append(node.identifier); + } + + void WgslWriter::Visit(Ast::IntrinsicExpression& node) + { + bool method = false; + bool firstParam = true; + switch (node.intrinsic) + { + // Function intrinsics + case Ast::IntrinsicType::Abs: + case Ast::IntrinsicType::All: + case Ast::IntrinsicType::Any: + case Ast::IntrinsicType::ArcCos: + case Ast::IntrinsicType::ArcCosh: + case Ast::IntrinsicType::ArcSin: + case Ast::IntrinsicType::ArcSinh: + case Ast::IntrinsicType::ArcTan: + case Ast::IntrinsicType::ArcTan2: + case Ast::IntrinsicType::ArcTanh: + case Ast::IntrinsicType::Ceil: + case Ast::IntrinsicType::Clamp: + case Ast::IntrinsicType::Cos: + case Ast::IntrinsicType::Cosh: + case Ast::IntrinsicType::CrossProduct: + case Ast::IntrinsicType::Distance: + case Ast::IntrinsicType::DotProduct: + case Ast::IntrinsicType::Exp: + case Ast::IntrinsicType::Exp2: + case Ast::IntrinsicType::Floor: + case Ast::IntrinsicType::Fract: + case Ast::IntrinsicType::Length: + case Ast::IntrinsicType::Log: + case Ast::IntrinsicType::Log2: + case Ast::IntrinsicType::MatrixTranspose: + case Ast::IntrinsicType::Max: + case Ast::IntrinsicType::Min: + case Ast::IntrinsicType::Normalize: + case Ast::IntrinsicType::Pow: + case Ast::IntrinsicType::Reflect: + case Ast::IntrinsicType::Round: + case Ast::IntrinsicType::Sign: + case Ast::IntrinsicType::Sin: + case Ast::IntrinsicType::Sinh: + case Ast::IntrinsicType::SmoothStep: + case Ast::IntrinsicType::Sqrt: + case Ast::IntrinsicType::Step: + case Ast::IntrinsicType::Tan: + case Ast::IntrinsicType::Tanh: + case Ast::IntrinsicType::Trunc: + { + auto intrinsicIt = LangData::s_intrinsicData.find(node.intrinsic); + assert(intrinsicIt != LangData::s_intrinsicData.end()); + assert(!intrinsicIt->second.functionName.empty()); + + Append(intrinsicIt->second.functionName); + break; + } + + case Ast::IntrinsicType::DegToRad: Append("radians"); break; + case Ast::IntrinsicType::InverseSqrt: Append("inverseSqrt"); break; + + case Ast::IntrinsicType::IsInf: + case Ast::IntrinsicType::IsNaN: + { + const Ast::ExpressionType& paramType = ResolveAlias(EnsureExpressionType(*node.parameters[0])); + const Ast::PrimitiveType& innerType = IsVectorType(paramType) ? std::get(paramType).type : std::get(paramType); + std::size_t componentCount = 1; + if (node.intrinsic == Ast::IntrinsicType::IsInf && IsVectorType(paramType)) + { + componentCount = std::get(paramType).componentCount; + Append("vec", componentCount, "("); + } + for(std::size_t i = 0; i < componentCount; i++) + { + if (i != 0) + Append(", "); + if (node.intrinsic == Ast::IntrinsicType::IsInf) + { + if (IsVectorType(paramType)) + { + const char* componentStr = "xyzw"; + node.parameters[0]->Visit(*this); + Append('.', componentStr[i]); + } + Append(" == _nzslInfinity", (innerType == Ast::PrimitiveType::Float32 ? "f32" : "f64"), "()"); + } + else + { + node.parameters[0]->Visit(*this); + Append(" != "); + node.parameters[0]->Visit(*this); + return; + } + } + if (IsVectorType(paramType)) + Append(")"); + return; + } + + case Ast::IntrinsicType::Lerp: Append("mix"); break; + + case Ast::IntrinsicType::MatrixInverse: + { + assert(IsMatrixType(EnsureExpressionType(*node.parameters[0]))); + const Ast::MatrixType& matrixType = std::get(EnsureExpressionType(*node.parameters[0])); + std::string_view stringPrimitiveType = (matrixType.type == Ast::PrimitiveType::Float32) ? "f32" : "f64"; + if (matrixType.columnCount == 2) + Append("_nzslMatrixInverse2x2", stringPrimitiveType); + else if (matrixType.columnCount == 3) + Append("_nzslMatrixInverse3x3", stringPrimitiveType); + else if (matrixType.columnCount == 4) + Append("_nzslMatrixInverse4x4", stringPrimitiveType); + break; + } + + case Ast::IntrinsicType::Not: Append("!"); break; + case Ast::IntrinsicType::RadToDeg: Append("degrees"); break; + case Ast::IntrinsicType::RoundEven: Append("round"); break; + + case Ast::IntrinsicType::Select: + { + const Ast::ExpressionType& condParamType = ResolveAlias(EnsureExpressionType(*node.parameters[0])); + const Ast::ExpressionType& firstParamType = ResolveAlias(EnsureExpressionType(*node.parameters[1])); + + Append("select("); + node.parameters[2]->Visit(*this); + Append(", "); + node.parameters[1]->Visit(*this); + Append(", "); + + // WGSL requires boolean vectors when selecting vectors + if (IsVectorType(firstParamType) && !IsVectorType(condParamType)) + { + std::size_t componentCount = std::get(firstParamType).componentCount; + + Append("vec", componentCount, "("); + node.parameters[0]->Visit(*this); + Append(")"); + } + else + node.parameters[0]->Visit(*this); + + Append(")"); + return; + } + + case Ast::IntrinsicType::TextureRead: Append("textureLoad"); break; + case Ast::IntrinsicType::TextureWrite: Append("textureStore"); break; + + // Method intrinsics + case Ast::IntrinsicType::ArraySize: + assert(!node.parameters.empty()); + firstParam = false; + if (node.parameters[0]->cachedExpressionType.has_value()) + { + auto value = node.parameters[0]->cachedExpressionType.value(); + if (IsArrayType(value) && std::get(value).length > 0) + { + Append(std::get(value).length); + return; + } + } + Append("arrayLength(&"); + node.parameters[0]->Visit(*this); + method = true; + break; + + case Ast::IntrinsicType::TextureSampleImplicitLod: + { + firstParam = false; + Append("textureSample("); + node.parameters[0]->Visit(*this); + Append(", "); + + if (node.parameters[0]->GetType() == Ast::NodeType::AccessIndexExpression) + { + Ast::AccessIndexExpression* accessExpr = static_cast(node.parameters[0].get()); + accessExpr->expr->Visit(*this); + } + else + node.parameters[0]->Visit(*this); + Append("Sampler, "); + method = true; + + const Ast::ExpressionType& textureType = EnsureExpressionType(*node.parameters[0]); + if (IsSamplerType(textureType) && std::get(textureType).dim == ImageType::E2D_Array) + { + node.parameters[1]->Visit(*this); + Append(".xy, u32("); + node.parameters[1]->Visit(*this); + Append(".z))"); + return; + } + break; + } + + case Ast::IntrinsicType::TextureSampleImplicitLodDepthComp: + { + firstParam = false; + Append("textureSampleCompare("); + node.parameters[0]->Visit(*this); + Append(", "); + + if (node.parameters[0]->GetType() == Ast::NodeType::AccessIndexExpression) + { + Ast::AccessIndexExpression* accessExpr = static_cast(node.parameters[0].get()); + accessExpr->expr->Visit(*this); + } + else + node.parameters[0]->Visit(*this); + Append("Sampler, "); + method = true; + + const Ast::ExpressionType& textureType = EnsureExpressionType(*node.parameters[0]); + if (IsSamplerType(textureType) && std::get(textureType).dim == ImageType::E2D_Array) + { + node.parameters[1]->Visit(*this); + Append(".xy, u32("); + node.parameters[1]->Visit(*this); + Append(".z), "); + node.parameters[2]->Visit(*this); + Append(')'); + return; + } + break; + } + } + + if (firstParam) + Append("("); + bool first = true; + for (std::size_t i = (method) ? 1 : 0; i < node.parameters.size(); ++i) + { + if (!first) + Append(", "); + + first = false; + + node.parameters[i]->Visit(*this); + } + Append(")"); + } + + void WgslWriter::Visit(Ast::SwizzleExpression& node) + { + Visit(node.expression, true); + Append("."); + + const char* componentStr = "xyzw"; + for (std::size_t i = 0; i < node.componentCount; ++i) + Append(componentStr[node.components[i]]); + } + + void WgslWriter::Visit(Ast::UnaryExpression& node) + { + switch (node.op) + { + case Ast::UnaryType::BitwiseNot: + Append("~"); + break; + + case Ast::UnaryType::LogicalNot: + Append("!"); + break; + case Ast::UnaryType::Minus: + Append("-"); + break; + + case Ast::UnaryType::Plus: + break; + } + + node.expression->Visit(*this); + } + + void WgslWriter::Visit(Ast::BranchStatement& node) + { + bool first = true; + for (const auto& statement : node.condStatements) + { + if (!first) + Append("else "); + + Append("if ("); + statement.condition->Visit(*this); + AppendLine(")"); + + ScopeVisit(*statement.statement); + + first = false; + } + + if (node.elseStatement) + { + AppendLine("else"); + + ScopeVisit(*node.elseStatement); + } + } + + void WgslWriter::Visit(Ast::BreakStatement& /*node*/) + { + Append("break;"); + } + + void WgslWriter::Visit(Ast::ConditionalStatement& /*node*/) + { + throw std::runtime_error("unexpected conditional statement, is shader sanitized?"); + } + + void WgslWriter::Visit(Ast::ContinueStatement& /*node*/) + { + Append("continue;"); + } + + void WgslWriter::Visit(Ast::DeclareAliasStatement& /*node*/) + { + // all aliases should have been handled by sanitizer + throw std::runtime_error("unexpected alias declaration, is shader sanitized?"); + } + + void WgslWriter::Visit(Ast::DeclareConstStatement& node) + { + if (node.constIndex) + RegisterConstant(*node.constIndex, node.name); + + Append("const ", node.name); + if (node.type.HasValue()) + Append(": ", node.type); + + if (node.expression) + { + Append(" = "); + node.expression->Visit(*this); + } + + AppendLine(";"); + } + + void WgslWriter::Visit(Ast::DeclareExternalStatement& node) + { + AppendAttributes(true, TagAttribute{ node.tag }); + + if (!node.name.empty()) + { + m_currentState->currentExternalBlockIndex = m_currentState->externalBlockNames.size(); + m_currentState->externalBlockNames.push_back(node.name); + } + + AppendLine(); + + for (const auto& externalVar : node.externalVars) + { + if (!externalVar.tag.empty() && m_currentState->backendParameters.debugLevel >= DebugLevel::Minimal) + AppendAttribute(false, TagAttribute{ externalVar.tag }); + + const Ast::ExpressionType& exprType = externalVar.type.GetResultingValue(); + + std::uint32_t binding = 0; + std::uint64_t bindingSet = (externalVar.bindingSet.HasValue()) ? externalVar.bindingSet.GetResultingValue() : 0; + + // Binding group declaration in WGSL are built like this + // @group(G) @binding(B) var name : TYPE; + + // Binding group handling + if (!IsPushConstantType(exprType)) // Push constants don't have set or binding + { + binding = externalVar.bindingIndex.GetResultingValue(); + for (; m_currentState->reservedBindings.count(bindingSet << 32 | binding); binding++); + m_currentState->reservedBindings.emplace(bindingSet << 32 | binding); + m_currentState->bindingRemap[bindingSet << 32 | externalVar.bindingIndex.GetResultingValue()] = binding; + + AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ Ast::ExpressionValue{ binding } }); + } + + Append("var"); + + // Address space handling + if (IsUniformType(exprType)) + Append(""); + else if (IsPushConstantType(exprType)) + Append(""); + else if (IsStorageType(exprType)) + { + const Ast::StorageType& storageType = std::get(exprType); + + Append(""); + } + + Append(' '); + + std::string variableName; + + if (m_currentState->currentModuleIndex != 0) + variableName += m_currentState->moduleNames[m_currentState->currentModuleIndex - 1] + '_'; + if (!node.name.empty()) + variableName += node.name + '_'; + variableName += externalVar.name; + Append(variableName, ": ", exprType); + + // Apply combined image sampler splitting + { + Ast::ExpressionType rawOrContainedType; + if (IsArrayType(exprType)) + rawOrContainedType = std::get(exprType).containedType->type; + else if (IsDynArrayType(exprType)) + rawOrContainedType = std::get(exprType).containedType->type; + else + rawOrContainedType = exprType; + + if (IsSamplerType(rawOrContainedType)) + { + // WGSL has not (yet?) combined image samplers so we need to split textures and samplers + AppendLine(';'); // Closing last line + AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ Ast::ExpressionValue{ binding + 1 } }); + m_currentState->reservedBindings.emplace(bindingSet << 32 | binding + 1); + Append("var ", variableName, "Sampler: sampler"); + if (std::get(rawOrContainedType).depth) + Append("_comparison"); + } + } + + AppendLine(';'); + + if (externalVar.varIndex) + RegisterVariable(*externalVar.varIndex, variableName); + } + + m_currentState->currentExternalBlockIndex = {}; + } + + void WgslWriter::Visit(Ast::DeclareFunctionStatement& node) + { + assert(m_currentState && "This function should only be called while processing an AST"); + + AppendAttributes(true, + EntryAttribute{ node.entryStage }, + WorkgroupAttribute{ node.workgroupSize }, + EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, + DepthWriteAttribute{ node.depthWrite } + ); + + Append("fn "); + + assert(node.funcIndex); + const auto& identifier = Nz::Retrieve(m_currentState->functions, *node.funcIndex); + if (identifier.moduleIndex != 0) + Append(m_currentState->moduleNames[identifier.moduleIndex - 1], '_', node.name, '('); + else + Append(node.name, '('); + for (std::size_t i = 0; i < node.parameters.size(); ++i) + { + const auto& parameter = node.parameters[i]; + + if (i != 0) + Append(", "); + + Append(parameter.name, ": "); + + if (parameter.semantic != Ast::FunctionParameterSemantic::In) + Append("ptr"); + else + Append(parameter.type); + + if (parameter.varIndex) + RegisterVariable(*parameter.varIndex, parameter.name, parameter.semantic != Ast::FunctionParameterSemantic::In); + + // Should sampler be inout if texture is inout ? + if (parameter.type.IsResultingValue()) + { + Ast::ExpressionType exprType = parameter.type.GetResultingValue(); + Ast::ExpressionType rawOrContainedType; + if (IsArrayType(exprType)) + rawOrContainedType = std::get(exprType).containedType->type; + else if (IsDynArrayType(exprType)) + rawOrContainedType = std::get(exprType).containedType->type; + else + rawOrContainedType = exprType; + + if (IsSamplerType(rawOrContainedType)) + { + if (IsArrayType(exprType) || IsDynArrayType(exprType)) + throw std::runtime_error("WGSL does not support sampled texture array as funtion parameter"); + Append(", ", parameter.name, "Sampler: sampler"); + if (std::get(rawOrContainedType).depth) + Append("_comparison"); + } + } + } + Append(')'); + if (node.returnType.HasValue()) + { + if (!node.returnType.IsResultingValue() || !IsNoType(node.returnType.GetResultingValue())) + Append(" -> ", node.returnType); + } + + AppendLine(); + EnterScope(); + { + AppendStatementList(node.statements); + } + LeaveScope(); + } + + void WgslWriter::Visit(Ast::DeclareOptionStatement& /*node*/) + { + // all options should have been handled by sanitizer + throw std::runtime_error("unexpected option declaration, is shader sanitized?"); + } + + void WgslWriter::Visit(Ast::DeclareStructStatement& node) + { + assert(node.structIndex); + RegisterStruct(*node.structIndex, node.description); + + AppendAttributes(true, TagAttribute{ node.description.tag }); + if (node.description.layout.HasValue() && node.description.layout.GetResultingValue() == Ast::MemoryLayout::Std140) // Only std140 is relevent for now + AppendComment("std140 layout"); + Append("struct "); + + assert(node.structIndex); + const auto& identifier = Nz::Retrieve(m_currentState->structs, *node.structIndex); + if (identifier.moduleIndex != 0) + Append(m_currentState->moduleNames[identifier.moduleIndex - 1], '_'); + AppendLine(node.description.name); + + EnterScope(); + { + const Ast::StructDescription::StructMember* dynArrayMember = nullptr; + bool first = true; + for (const auto& member : node.description.members) + { + // If builtin needs emulation, skip struct declaration as all shader + // input struct members need builtin or location attributes + if (member.builtin.HasValue()) + { + if (std::find(s_wgslBuiltinsToEmulate.begin(), s_wgslBuiltinsToEmulate.end(), member.builtin.GetResultingValue()) != s_wgslBuiltinsToEmulate.end()) + continue; + } + + // Runtime sized arrays should always be last in structs + // See https://www.w3.org/TR/WGSL/#struct-types + if (IsDynArrayType(member.type.GetResultingValue())) + { + if (dynArrayMember != nullptr) + throw std::runtime_error("WGSL structures can only have a single runtime sized array"); + dynArrayMember = &member; + continue; + } + + if (!first) + AppendLine(","); + first = false; + + AppendAttributes(false, CondAttribute{ member.cond }, LocationAttribute{ member.locationIndex }, InterpAttribute{ member.interp }, BuiltinAttribute{ member.builtin }, TagAttribute{ member.tag }); + Append(member.name, ": ", member.type); + } + + if (dynArrayMember) + { + if (!first) + AppendLine(","); + AppendAttributes(false, CondAttribute{ dynArrayMember->cond }, LocationAttribute{ dynArrayMember->locationIndex }, InterpAttribute{ dynArrayMember->interp }, BuiltinAttribute{ dynArrayMember->builtin }, TagAttribute{ dynArrayMember->tag }); + Append(dynArrayMember->name, ": ", dynArrayMember->type); + } + } + LeaveScope(); + } + + void WgslWriter::Visit(Ast::DeclareVariableStatement& node) + { + if (node.varIndex) + RegisterVariable(*node.varIndex, node.varName); + + Append("var "); + Append(node.varName); + if (node.varType.HasValue()) + Append(": ", node.varType); + + if (node.initialExpression) + { + Append(" = "); + node.initialExpression->Visit(*this); + } + + Append(";"); + } + + void WgslWriter::Visit(Ast::DiscardStatement& /*node*/) + { + Append("discard;"); + } + + void WgslWriter::Visit(Ast::ExpressionStatement& node) + { + node.expression->Visit(*this); + Append(";"); + } + + void WgslWriter::Visit(Ast::ForStatement& /*node*/) + { + // For loops must have been converted to while loop in prepasses + throw std::runtime_error("unexpected for statement, is the shader sanitized?"); + } + + void WgslWriter::Visit(Ast::ForEachStatement& /*node*/) + { + throw std::runtime_error("unexpected for each statement, is the shader sanitized?"); + } + + void WgslWriter::Visit(Ast::ImportStatement& /*node*/) + { + throw std::runtime_error("unexpected import statement, is the shader sanitized?"); + } + + void WgslWriter::Visit(Ast::MultiStatement& node) + { + AppendStatementList(node.statements); + } + + void WgslWriter::Visit(Ast::NoOpStatement& /*node*/) + { + /* nothing to do */ + } + + void WgslWriter::Visit(Ast::ReturnStatement& node) + { + if (node.returnExpr) + { + Append("return "); + node.returnExpr->Visit(*this); + Append(";"); + } + else + Append("return;"); + } + + void WgslWriter::Visit(Ast::ScopedStatement& node) + { + EnterScope(); + node.statement->Visit(*this); + LeaveScope(); + } + + void WgslWriter::Visit(Ast::WhileStatement& node) + { + Append("while ("); + node.condition->Visit(*this); + AppendLine(")"); + + ScopeVisit(*node.body); + } + + void WgslWriter::AppendHeader(const Ast::Module::Metadata& metadata) + { + AppendComment("This file was generated by NZSL compiler (Nazara Shading Language)"); + AppendModuleAttributes(metadata); + AppendLine(); + } +} diff --git a/src/ShaderCompiler/Compiler.cpp b/src/ShaderCompiler/Compiler.cpp index 79a02ba0..9ae05947 100644 --- a/src/ShaderCompiler/Compiler.cpp +++ b/src/ShaderCompiler/Compiler.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -230,6 +231,7 @@ namespace nzslc - nzslb : binary NZSL - spv : binary SPIR-V - spv-dis : textual SPIR-V +- wgsl : WGSL Multiple values can be specified using commas (ex: --compile=glsl,nzslb). You can also specify -header as a suffix (ex: --compile=glsl-header) to generate an includable header file. @@ -322,6 +324,8 @@ You can also specify -header as a suffix (ex: --compile=glsl-header) to generate Step("Compile to textual SPIR-V", __LINE__, &Compiler::CompileToSPV, outputFilePath, *targetModule, true); else if (outputType == "glsl") Step("Compile to GLSL", __LINE__, &Compiler::CompileToGLSL, outputFilePath, *targetModule); + else if (outputType == "wgsl") + Step("Compile to WGSL", __LINE__, &Compiler::CompileToWGSL, outputFilePath, *targetModule); else { fmt::print("Unknown format {}, ignoring\n", outputType); @@ -588,6 +592,30 @@ You can also specify -header as a suffix (ex: --compile=glsl-header) to generate } } + void Compiler::CompileToWGSL(std::filesystem::path outputPath, nzsl::Ast::Module& module) + { + // TODO : add a way to validate Wgsl feature usage + nzsl::WgslWriter::Environment env; + env.featuresCallback = [](std::string_view) { return true; }; + + nzsl::WgslWriter wgslWriter; + wgslWriter.SetEnv(env); + + nzsl::BackendParameters states = BuildWriterOptions(); + nzsl::WgslWriter::Output output = wgslWriter.Generate(module, states); + if (m_skipOutput) + return; + + if (m_outputToStdout) + { + OutputToStdout(output.code); + return; + } + + outputPath.replace_extension("wgsl"); + OutputFile(std::move(outputPath), output.code.data(), output.code.size()); + } + nzsl::Ast::ModulePtr Compiler::Deserialize(const std::uint8_t* data, std::size_t size) { nzsl::Deserializer deserializer(data, size); diff --git a/src/ShaderCompiler/Compiler.hpp b/src/ShaderCompiler/Compiler.hpp index 921bcd32..9af19534 100644 --- a/src/ShaderCompiler/Compiler.hpp +++ b/src/ShaderCompiler/Compiler.hpp @@ -59,6 +59,7 @@ namespace nzslc void CompileToNZSL(std::filesystem::path outputPath, const nzsl::Ast::Module& module); void CompileToNZSLB(std::filesystem::path outputPath, const nzsl::Ast::Module& module); void CompileToSPV(std::filesystem::path outputPath, nzsl::Ast::Module& module, bool textual); + void CompileToWGSL(std::filesystem::path outputPath, nzsl::Ast::Module& module); nzsl::Ast::ModulePtr Deserialize(const std::uint8_t* data, std::size_t size); void PrintTime(); void OutputFile(std::filesystem::path filePath, const void* data, std::size_t size, bool disallowHeader = false); diff --git a/tests/src/Tests/AccessMemberTests.cpp b/tests/src/Tests/AccessMemberTests.cpp index dc2431c4..edca8cee 100644 --- a/tests/src/Tests/AccessMemberTests.cpp +++ b/tests/src/Tests/AccessMemberTests.cpp @@ -27,7 +27,6 @@ external [set(0), binding(0)] ubo: uniform[outerStruct] } )"; - nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); ResolveModule(*shaderModule); @@ -40,7 +39,7 @@ external auto swizzle = nzsl::ShaderBuilder::Swizzle(std::move(secondAccess), { 2u }); auto varDecl = nzsl::ShaderBuilder::DeclareVariable("result", nzsl::Ast::ExpressionType{ nzsl::Ast::PrimitiveType::Float32 }, std::move(swizzle)); - shaderModule->rootNode->statements.push_back(nzsl::ShaderBuilder::DeclareFunction(nzsl::ShaderStageType::Vertex, "main", std::move(varDecl))); + shaderModule->rootNode->statements.push_back(nzsl::ShaderBuilder::DeclareFunction(nzsl::ShaderStageType::Fragment, "main", std::move(varDecl))); ExpectGLSL(*shaderModule, R"( void main() @@ -50,7 +49,7 @@ void main() )"); ExpectNZSL(*shaderModule, R"( -[entry(vert)] +[entry(frag)] fn main() { let result: f32 = ubo.s.field.z; @@ -67,6 +66,14 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var result: f32 = ubo.s.field.z; +} +)"); } SECTION("AccessMember with multiples fields") @@ -77,7 +84,7 @@ OpFunctionEnd)"); auto swizzle = nzsl::ShaderBuilder::Swizzle(std::move(access), { 2u }); auto varDecl = nzsl::ShaderBuilder::DeclareVariable("result", nzsl::Ast::ExpressionType{ nzsl::Ast::PrimitiveType::Float32 }, std::move(swizzle)); - shaderModule->rootNode->statements.push_back(nzsl::ShaderBuilder::DeclareFunction(nzsl::ShaderStageType::Vertex, "main", std::move(varDecl))); + shaderModule->rootNode->statements.push_back(nzsl::ShaderBuilder::DeclareFunction(nzsl::ShaderStageType::Fragment, "main", std::move(varDecl))); ExpectGLSL(*shaderModule, R"( void main() @@ -87,7 +94,7 @@ void main() )"); ExpectNZSL(*shaderModule, R"( -[entry(vert)] +[entry(frag)] fn main() { let result: f32 = ubo.s.field.z; @@ -104,6 +111,14 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var result: f32 = ubo.s.field.z; +} +)"); } } } diff --git a/tests/src/Tests/AliasTests.cpp b/tests/src/Tests/AliasTests.cpp index 3cd1c92c..f26b05c6 100644 --- a/tests/src/Tests/AliasTests.cpp +++ b/tests/src/Tests/AliasTests.cpp @@ -3,7 +3,6 @@ #include #include #include -#include TEST_CASE("aliases", "[Shader]") { @@ -15,7 +14,7 @@ module; struct Data { - value: f32 + value: vec4[f32] } alias ExtData = Data; @@ -27,14 +26,14 @@ external struct Input { - value: f32 + [location(0)] value: vec4[f32] } alias In = Input; struct Output { - [location(0)] value: f32 + [location(0)] value: vec4[f32] } alias Out = Output; @@ -56,7 +55,7 @@ fn main(input: In) -> FragOut void main() { Input input_; - input_.value = _nzslInvalue; + input_.value = _nzslVarying0; Output output_; output_.value = extData.value * input_.value; @@ -82,6 +81,8 @@ OpLabel OpVariable OpVariable OpAccessChain +OpCopyMemory +OpAccessChain OpLoad OpAccessChain OpLoad @@ -93,6 +94,16 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main(input: Input) -> Output +{ + var output: Output; + output.value = extData.value * input.value; + return output; +} +)"); } SECTION("Conditional aliases") @@ -218,6 +229,16 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() -> ForwardOutput +{ + var output: ForwardOutput; + output.color = vec4(0.0, 0.0, 1.0, 1.0); + return output; +} +)"); } WHEN("We disable ForwardPass") @@ -299,6 +320,17 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() -> DeferredOutput +{ + var output: DeferredOutput; + output.color = vec4(0.0, 0.0, 1.0, 1.0); + output.normal = vec3(0.0, 1.0, 0.0); + return output; +} +)"); } } } diff --git a/tests/src/Tests/ArithmeticTests.cpp b/tests/src/Tests/ArithmeticTests.cpp index 6106e57f..3f0dcc60 100644 --- a/tests/src/Tests/ArithmeticTests.cpp +++ b/tests/src/Tests/ArithmeticTests.cpp @@ -268,6 +268,41 @@ fn main() OpStore %48 %112 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: i32 = 5; + var y: i32 = 2; + var r: i32 = x + y; + var r_2: i32 = x - y; + var r_3: i32 = x * y; + var r_4: i32 = x / y; + var r_5: i32 = x % y; + var x_2: f32 = 5.0; + var y_2: f32 = 2.0; + var r_6: f32 = x_2 + y_2; + var r_7: f32 = x_2 - y_2; + var r_8: f32 = x_2 * y_2; + var r_9: f32 = x_2 / y_2; + var r_10: f32 = x_2 % y_2; + var x_3: vec2 = vec2(5, 7); + var y_3: vec2 = vec2(2, 3); + var r_11: vec2 = x_3 + y_3; + var r_12: vec2 = x_3 - y_3; + var r_13: vec2 = x_3 * y_3; + var r_14: vec2 = x_3 / y_3; + var r_15: vec2 = x_3 % y_3; + var x_4: vec2 = vec2(5.0, 7.0); + var y_4: vec2 = vec2(2.0, 3.0); + var r_16: vec2 = x_4 + y_4; + var r_17: vec2 = x_4 - y_4; + var r_18: vec2 = x_4 * y_4; + var r_19: vec2 = x_4 / y_4; + var r_20: vec2 = x_4 % y_4; +} +)"); } SECTION("Bitwise operations") @@ -533,6 +568,41 @@ fn main() OpStore %48 %112 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: i32 = 5; + var y: i32 = 2; + var r: i32 = x & y; + var r_2: i32 = x | y; + var r_3: i32 = x ^ y; + var r_4: i32 = x << u32(y); + var r_5: i32 = x >> u32(y); + var x_2: u32 = 5u; + var y_2: u32 = 2u; + var r_6: u32 = x_2 & y_2; + var r_7: u32 = x_2 | y_2; + var r_8: u32 = x_2 ^ y_2; + var r_9: u32 = x_2 << y_2; + var r_10: u32 = x_2 >> y_2; + var x_3: vec3 = vec3(0, 1, 2); + var y_3: vec3 = vec3(2, 1, 0); + var r_11: vec3 = x_3 & y_3; + var r_12: vec3 = x_3 | y_3; + var r_13: vec3 = x_3 ^ y_3; + var r_14: vec3 = x_3 << vec3(y_3); + var r_15: vec3 = x_3 >> vec3(y_3); + var x_4: vec3 = vec3(0u, 1u, 2u); + var y_4: vec3 = vec3(2u, 1u, 0u); + var r_16: vec3 = x_4 & y_4; + var r_17: vec3 = x_4 | y_4; + var r_18: vec3 = x_4 ^ y_4; + var r_19: vec3 = x_4 << y_4; + var r_20: vec3 = x_4 >> y_4; +} +)"); // r_9 and r_10 expressions should perhaps not cast the right node to u32 } SECTION("Matrix/matrix operations") @@ -737,6 +807,39 @@ fn main() OpStore %18 %131 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var _nzsl_matrix: mat3x3; + var _nzsl_cachedResult: f32 = 0.0; + _nzsl_matrix[0u] = vec3(_nzsl_cachedResult, 0.0, 0.0); + _nzsl_matrix[1u] = vec3(0.0, _nzsl_cachedResult, 0.0); + _nzsl_matrix[2u] = vec3(0.0, 0.0, _nzsl_cachedResult); + var x: mat3x3 = _nzsl_matrix; + var _nzsl_matrix_2: mat3x3; + var _nzsl_cachedResult_2: f32 = 1.0; + _nzsl_matrix_2[0u] = vec3(_nzsl_cachedResult_2, 0.0, 0.0); + _nzsl_matrix_2[1u] = vec3(0.0, _nzsl_cachedResult_2, 0.0); + _nzsl_matrix_2[2u] = vec3(0.0, 0.0, _nzsl_cachedResult_2); + var y: mat3x3 = _nzsl_matrix_2; + var _nzsl_matrix_3: mat3x3; + _nzsl_matrix_3[0u] = x[0u] + y[0u]; + _nzsl_matrix_3[1u] = x[1u] + y[1u]; + _nzsl_matrix_3[2u] = x[2u] + y[2u]; + var r: mat3x3 = _nzsl_matrix_3; + var _nzsl_matrix_4: mat3x3; + _nzsl_matrix_4[0u] = x[0u] - y[0u]; + _nzsl_matrix_4[1u] = x[1u] - y[1u]; + _nzsl_matrix_4[2u] = x[2u] - y[2u]; + var r_2: mat3x3 = _nzsl_matrix_4; + var r_3: mat3x3 = x * y; + x += y; + x -= y; + x *= y; +} +)"); } SECTION("Matrix/scalars operations") @@ -815,6 +918,22 @@ fn main() OpStore %22 %39 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var _nzsl_matrix: mat3x3; + var _nzsl_cachedResult: f32 = 1.0; + _nzsl_matrix[0u] = vec3(_nzsl_cachedResult, 0.0, 0.0); + _nzsl_matrix[1u] = vec3(0.0, _nzsl_cachedResult, 0.0); + _nzsl_matrix[2u] = vec3(0.0, 0.0, _nzsl_cachedResult); + var mat: mat3x3 = _nzsl_matrix; + var val: f32 = 42.0; + var r: mat3x3 = mat * val; + var r_2: mat3x3 = val * mat; +} +)"); } SECTION("Vector/vector operations") @@ -957,6 +1076,27 @@ fn main() OpStore %30 %64 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: vec3 = vec3(0.0, 1.0, 2.0); + var y: vec3 = vec3(2.0, 1.0, 0.0); + var r: vec3 = x + y; + var r_2: vec3 = x - y; + var r_3: vec3 = x * y; + var r_4: vec3 = x / y; + var r_5: vec3 = x % y; + var x_2: vec3 = vec3(0u, 1u, 2u); + var y_2: vec3 = vec3(2u, 1u, 0u); + var r_6: vec3 = x_2 + y_2; + var r_7: vec3 = x_2 - y_2; + var r_8: vec3 = x_2 * y_2; + var r_9: vec3 = x_2 / y_2; + var r_10: vec3 = x_2 % y_2; +} +)"); } SECTION("Vector/scalars operations") @@ -1123,6 +1263,29 @@ fn main() OpStore %38 %85 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var vec: vec4 = vec4(1, 2, 3, 4); + var val: i32 = 42; + var r: vec4 = vec * val; + var r_2: vec4 = val * vec; + var r_3: vec4 = vec / val; + var r_4: vec4 = val / vec; + var r_5: vec4 = vec % val; + var r_6: vec4 = val % vec; + var vec_2: vec4 = vec4(1.0, 2.0, 3.0, 4.0); + var val_2: f32 = 42.0; + var r_7: vec4 = vec_2 * val_2; + var r_8: vec4 = val_2 * vec_2; + var r_9: vec4 = vec_2 / val_2; + var r_10: vec4 = val_2 / vec_2; + var r_11: vec4 = vec_2 % val_2; + var r_12: vec4 = val_2 % vec_2; +} +)"); } SECTION("Unary operators") @@ -1234,5 +1397,20 @@ fn main() OpStore %35 %46 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var r: f32 = 42.0; + var r_2: f32 = -6.0; + var r_3: f32 = (-r_2) * (r_2); + var r_4: i32 = ~42; + var r_5: u32 = ~42u; + var r_6: bool = !true; + var r_7: vec3 = ~vec3(1, 2, 3); + var r_8: vec3 = ~vec3(1u, 2u, 3u); +} +)"); } } diff --git a/tests/src/Tests/ArrayTests.cpp b/tests/src/Tests/ArrayTests.cpp index 8386b2da..de771b19 100644 --- a/tests/src/Tests/ArrayTests.cpp +++ b/tests/src/Tests/ArrayTests.cpp @@ -20,7 +20,7 @@ const vertices = array[vec3[f32]]( struct VertIn { - [builtin(vertex_index)] vert_index: i32 + [builtin(vertex_index)] vert_index: u32 } struct VertOut @@ -32,7 +32,7 @@ struct VertOut fn main(input: VertIn) -> VertOut { let output: VertOut; - if (u32(input.vert_index) < vertices.Size()) + if (input.vert_index < vertices.Size()) output.pos = vec4[f32](vertices[input.vert_index], 1.0); else output.pos = vec4[f32](0.0, 0.0, 0.0, 0.0); @@ -56,7 +56,7 @@ vec3 vertices[3] = vec3[3]( ); struct VertIn { - int vert_index; + uint vert_index; }; struct VertOut @@ -67,10 +67,10 @@ struct VertOut void main() { VertIn input_; - input_.vert_index = gl_VertexID; + input_.vert_index = uint(gl_VertexID); VertOut output_; - if ((uint(input_.vert_index)) < (uint(vertices.length()))) + if (input_.vert_index < (uint(vertices.length()))) { output_.pos = vec4(vertices[input_.vert_index], 1.0); } @@ -90,7 +90,7 @@ const vertices: array[vec3[f32], 3] = array[vec3[f32], 3](vec3[f32](1.0, 2.0, 3. struct VertIn { - [builtin(vertex_index)] vert_index: i32 + [builtin(vertex_index)] vert_index: u32 } struct VertOut @@ -102,7 +102,7 @@ struct VertOut fn main(input: VertIn) -> VertOut { let output: VertOut; - if ((u32(input.vert_index)) < (vertices.Size())) + if (input.vert_index < (vertices.Size())) { output.pos = vec4[f32](vertices[input.vert_index], 1.0); } @@ -137,11 +137,11 @@ fn main(input: VertIn) -> VertOut %19 = OpConstantComposite %5 %10 %14 %18 %21 = OpTypeVoid %22 = OpTypeFunction %21 -%23 = OpTypeInt 32 1 -%24 = OpTypePointer StorageClass(Input) %23 -%26 = OpConstant %23 i32(0) -%27 = OpTypePointer StorageClass(Function) %23 -%28 = OpTypeStruct %23 +%23 = OpTypePointer StorageClass(Input) %3 +%25 = OpTypeInt 32 1 +%26 = OpConstant %25 i32(0) +%27 = OpTypePointer StorageClass(Function) %3 +%28 = OpTypeStruct %3 %29 = OpTypePointer StorageClass(Function) %28 %30 = OpTypeVector %1 4 %31 = OpTypePointer StorageClass(Output) %30 @@ -149,18 +149,18 @@ fn main(input: VertIn) -> VertOut %34 = OpTypePointer StorageClass(Function) %33 %35 = OpTypeBool %36 = OpConstant %1 f32(0) -%37 = OpConstant %23 i32(1) -%38 = OpConstant %23 i32(2) -%39 = OpConstant %23 i32(3) -%40 = OpConstant %23 i32(4) -%41 = OpConstant %23 i32(5) +%37 = OpConstant %25 i32(1) +%38 = OpConstant %25 i32(2) +%39 = OpConstant %25 i32(3) +%40 = OpConstant %25 i32(4) +%41 = OpConstant %25 i32(5) %42 = OpConstant %3 u32(5) -%43 = OpTypeArray %23 %42 +%43 = OpTypeArray %25 %42 %44 = OpTypePointer StorageClass(Function) %43 -%60 = OpTypePointer StorageClass(Private) %2 -%65 = OpTypePointer StorageClass(Function) %30 +%59 = OpTypePointer StorageClass(Private) %2 +%64 = OpTypePointer StorageClass(Function) %30 %20 = OpVariable %6 StorageClass(Private) %19 -%25 = OpVariable %24 StorageClass(Input) +%24 = OpVariable %23 StorageClass(Input) %32 = OpVariable %31 StorageClass(Output) %45 = OpFunction %21 FunctionControl(0) %22 %46 = OpLabel @@ -168,34 +168,67 @@ fn main(input: VertIn) -> VertOut %48 = OpVariable %44 StorageClass(Function) %49 = OpVariable %29 StorageClass(Function) %50 = OpAccessChain %27 %49 %26 - OpCopyMemory %50 %25 + OpCopyMemory %50 %24 %54 = OpAccessChain %27 %49 %26 -%55 = OpLoad %23 %54 -%56 = OpBitcast %3 %55 -%57 = OpULessThan %35 %56 %4 +%55 = OpLoad %3 %54 +%56 = OpULessThan %35 %55 %4 OpSelectionMerge %51 SelectionControl(0) - OpBranchConditional %57 %52 %53 + OpBranchConditional %56 %52 %53 %52 = OpLabel -%58 = OpAccessChain %27 %49 %26 -%59 = OpLoad %23 %58 -%61 = OpAccessChain %60 %20 %59 -%62 = OpLoad %2 %61 -%63 = OpCompositeConstruct %30 %62 %7 -%64 = OpAccessChain %65 %47 %26 - OpStore %64 %63 +%57 = OpAccessChain %27 %49 %26 +%58 = OpLoad %3 %57 +%60 = OpAccessChain %59 %20 %58 +%61 = OpLoad %2 %60 +%62 = OpCompositeConstruct %30 %61 %7 +%63 = OpAccessChain %64 %47 %26 + OpStore %63 %62 OpBranch %51 %53 = OpLabel -%66 = OpCompositeConstruct %30 %36 %36 %36 %36 -%67 = OpAccessChain %65 %47 %26 - OpStore %67 %66 +%65 = OpCompositeConstruct %30 %36 %36 %36 %36 +%66 = OpAccessChain %64 %47 %26 + OpStore %66 %65 OpBranch %51 %51 = OpLabel -%68 = OpCompositeConstruct %43 %37 %38 %39 %40 %41 - OpStore %48 %68 -%69 = OpLoad %33 %47 -%70 = OpCompositeExtract %30 %69 0 - OpStore %32 %70 +%67 = OpCompositeConstruct %43 %37 %38 %39 %40 %41 + OpStore %48 %67 +%68 = OpLoad %33 %47 +%69 = OpCompositeExtract %30 %68 0 + OpStore %32 %69 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +const vertices: array, 3> = array, 3>( + vec3(1.0, 2.0, 3.0), + vec3(4.0, 5.0, 6.0), + vec3(7.0, 8.0, 9.0) +); + +struct VertIn +{ + @builtin(vertex_index) vert_index: u32 +} + +struct VertOut +{ + @builtin(position) pos: vec4 +} + +@vertex +fn main(input: VertIn) -> VertOut +{ + var output: VertOut; + if (input.vert_index < (3)) + { + output.pos = vec4(vertices[input.vert_index], 1.0); + } + else + { + output.pos = vec4(0.0, 0.0, 0.0, 0.0); + } + + var customData: array = array(1, 2, 3, 4, 5); + return output; +})"); } } diff --git a/tests/src/Tests/BranchTests.cpp b/tests/src/Tests/BranchTests.cpp index 328ad64d..00c297c7 100644 --- a/tests/src/Tests/BranchTests.cpp +++ b/tests/src/Tests/BranchTests.cpp @@ -87,6 +87,23 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var value: f32; + if (data.value > 0.0) + { + value = 1.0; + } + else + { + value = 0.0; + } + +} +)"); } WHEN("using a more complex branch") @@ -178,6 +195,23 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var value: f32; + if ((data.value > 42.0) || ((data.value <= 50.0) && (data.value < 0.0))) + { + value = 1.0; + } + else + { + value = 0.0; + } + +} +)"); } WHEN("discarding in a branch") @@ -269,6 +303,21 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() -> Output +{ + if (data.value > 0.0) + { + discard; + } + + var output: Output; + output.color = vec4(1.0, 1.0, 1.0, 1.0); + return output; +} +)"); } WHEN("discarding in a const branch") @@ -337,6 +386,20 @@ OpLabel OpVariable OpKill OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() -> Output +{ + { + discard; + } + + var output: Output; + output.color = vec4(1.0, 1.0, 1.0, 1.0); + return output; +} +)"); } WHEN("using a complex branch") @@ -462,5 +525,30 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var value: f32; + if (data.value >= 3.0) + { + value = 3.0; + } + else if (data.value > 2.0) + { + value = 2.0; + } + else if (data.value > 1.0) + { + value = 1.0; + } + else + { + value = 0.0; + } + +} +)"); } } diff --git a/tests/src/Tests/BuiltinAttributeTests.cpp b/tests/src/Tests/BuiltinAttributeTests.cpp index 3c7ed8d7..36b1cf2a 100644 --- a/tests/src/Tests/BuiltinAttributeTests.cpp +++ b/tests/src/Tests/BuiltinAttributeTests.cpp @@ -15,11 +15,11 @@ module; struct Input { - [builtin(base_instance)] base_instance: i32, - [builtin(base_vertex)] base_vertex: i32, - [builtin(draw_index)] draw_index: i32, - [builtin(instance_index)] instance_index: i32, - [builtin(vertex_index)] vertex_index: i32, + [builtin(base_instance)] base_instance: u32, + [builtin(base_vertex)] base_vertex: u32, + [builtin(draw_index)] draw_index: u32, + [builtin(instance_index)] instance_index: u32, + [builtin(vertex_index)] vertex_index: u32, } struct Output @@ -67,11 +67,11 @@ fn main(input: Input) -> Output struct Input { - int base_instance; - int base_vertex; - int draw_index; - int instance_index; - int vertex_index; + uint base_instance; + uint base_vertex; + uint draw_index; + uint instance_index; + uint vertex_index; }; struct Output @@ -82,17 +82,17 @@ struct Output void main() { Input input_; - input_.base_instance = gl_BaseInstanceARB; - input_.base_vertex = gl_BaseVertexARB; - input_.draw_index = gl_DrawIDARB; - input_.instance_index = (gl_BaseInstanceARB + gl_InstanceID); - input_.vertex_index = gl_VertexID; - - int bi = input_.base_instance; - int bv = input_.base_vertex; - int di = input_.draw_index; - int ii = input_.instance_index; - int vi = input_.vertex_index; + input_.base_instance = uint(gl_BaseInstanceARB); + input_.base_vertex = uint(gl_BaseVertexARB); + input_.draw_index = uint(gl_DrawIDARB); + input_.instance_index = uint(gl_BaseInstanceARB) + uint(gl_InstanceID); + input_.vertex_index = uint(gl_VertexID); + + uint bi = input_.base_instance; + uint bv = input_.base_vertex; + uint di = input_.draw_index; + uint ii = input_.instance_index; + uint vi = input_.vertex_index; float color = float((((bi + bv) + di) + ii) + vi); Output output_; output_.position = vec4(color, color, color, color); @@ -112,11 +112,11 @@ void main() ExpectGLSL(*shaderModule, R"( struct Input { - int base_instance; - int base_vertex; - int draw_index; - int instance_index; - int vertex_index; + uint base_instance; + uint base_vertex; + uint draw_index; + uint instance_index; + uint vertex_index; }; struct Output @@ -127,17 +127,17 @@ struct Output void main() { Input input_; - input_.base_instance = gl_BaseInstance; - input_.base_vertex = gl_BaseVertex; - input_.draw_index = gl_DrawID; - input_.instance_index = (gl_BaseInstance + gl_InstanceID); - input_.vertex_index = gl_VertexID; - - int bi = input_.base_instance; - int bv = input_.base_vertex; - int di = input_.draw_index; - int ii = input_.instance_index; - int vi = input_.vertex_index; + input_.base_instance = uint(gl_BaseInstance); + input_.base_vertex = uint(gl_BaseVertex); + input_.draw_index = uint(gl_DrawID); + input_.instance_index = uint(gl_BaseInstance) + uint(gl_InstanceID); + input_.vertex_index = uint(gl_VertexID); + + uint bi = input_.base_instance; + uint bv = input_.base_vertex; + uint di = input_.draw_index; + uint ii = input_.instance_index; + uint vi = input_.vertex_index; float color = float((((bi + bv) + di) + ii) + vi); Output output_; output_.position = vec4(color, color, color, color); @@ -153,19 +153,19 @@ void main() glslEnv.allowDrawParametersUniformsFallback = true; ExpectGLSL(*shaderModule, R"( -uniform int _nzslBaseInstance; -uniform int _nzslBaseVertex; -uniform int _nzslDrawID; +uniform uint _nzslBaseInstance; +uniform uint _nzslBaseVertex; +uniform uint _nzslDrawID; // header end struct Input { - int base_instance; - int base_vertex; - int draw_index; - int instance_index; - int vertex_index; + uint base_instance; + uint base_vertex; + uint draw_index; + uint instance_index; + uint vertex_index; }; struct Output @@ -179,14 +179,14 @@ void main() input_.base_instance = _nzslBaseInstance; input_.base_vertex = _nzslBaseVertex; input_.draw_index = _nzslDrawID; - input_.instance_index = (_nzslBaseInstance + gl_InstanceID); - input_.vertex_index = gl_VertexID; - - int bi = input_.base_instance; - int bv = input_.base_vertex; - int di = input_.draw_index; - int ii = input_.instance_index; - int vi = input_.vertex_index; + input_.instance_index = _nzslBaseInstance + uint(gl_InstanceID); + input_.vertex_index = uint(gl_VertexID); + + uint bi = input_.base_instance; + uint bv = input_.base_vertex; + uint di = input_.draw_index; + uint ii = input_.instance_index; + uint vi = input_.vertex_index; float color = float((((bi + bv) + di) + ii) + vi); Output output_; output_.position = vec4(color, color, color, color); @@ -200,11 +200,11 @@ void main() ExpectNZSL(*shaderModule, R"( struct Input { - [builtin(base_instance)] base_instance: i32, - [builtin(base_vertex)] base_vertex: i32, - [builtin(draw_index)] draw_index: i32, - [builtin(instance_index)] instance_index: i32, - [builtin(vertex_index)] vertex_index: i32 + [builtin(base_instance)] base_instance: u32, + [builtin(base_vertex)] base_vertex: u32, + [builtin(draw_index)] draw_index: u32, + [builtin(instance_index)] instance_index: u32, + [builtin(vertex_index)] vertex_index: u32 } struct Output @@ -215,11 +215,11 @@ struct Output [entry(vert)] fn main(input: Input) -> Output { - let bi: i32 = input.base_instance; - let bv: i32 = input.base_vertex; - let di: i32 = input.draw_index; - let ii: i32 = input.instance_index; - let vi: i32 = input.vertex_index; + let bi: u32 = input.base_instance; + let bv: u32 = input.base_vertex; + let di: u32 = input.draw_index; + let ii: u32 = input.instance_index; + let vi: u32 = input.vertex_index; let color: f32 = f32((((bi + bv) + di) + ii) + vi); let output: Output; output.position = color.xxxx; @@ -240,12 +240,51 @@ fn main(input: Input) -> Output ExpectSPIRV(*shaderModule, R"( OpDecorate %5 Decoration(BuiltIn) BuiltIn(BaseInstance) - OpDecorate %8 Decoration(BuiltIn) BuiltIn(BaseVertex) - OpDecorate %10 Decoration(BuiltIn) BuiltIn(DrawIndex) - OpDecorate %12 Decoration(BuiltIn) BuiltIn(InstanceIndex) - OpDecorate %14 Decoration(BuiltIn) BuiltIn(VertexIndex) - OpDecorate %21 Decoration(BuiltIn) BuiltIn(Position))", {}, spirvEnv, true); + OpDecorate %9 Decoration(BuiltIn) BuiltIn(BaseVertex) + OpDecorate %11 Decoration(BuiltIn) BuiltIn(DrawIndex) + OpDecorate %13 Decoration(BuiltIn) BuiltIn(InstanceIndex) + OpDecorate %15 Decoration(BuiltIn) BuiltIn(VertexIndex) + OpDecorate %22 Decoration(BuiltIn) BuiltIn(Position))", {}, spirvEnv, true); } + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +struct _nzslBuiltinEmulationStruct +{ + base_instance: u32, + base_vertex: u32, + draw_index: u32, + +} +@group(0) @binding(0) var _nzslBuiltinEmulation: _nzslBuiltinEmulationStruct; + +struct Input +{ + @builtin(instance_index) instance_index: u32, + @builtin(vertex_index) vertex_index: u32 +} + +struct Output +{ + @builtin(position) position: vec4 +} + +@vertex +fn main(input: Input) -> Output +{ + var bi: u32 = _nzslBuiltinEmulation.base_instance; + var bv: u32 = _nzslBuiltinEmulation.base_vertex; + var di: u32 = _nzslBuiltinEmulation.draw_index; + var ii: u32 = input.instance_index; + var vi: u32 = input.vertex_index; + var color: f32 = f32((((bi + bv) + di) + ii) + vi); + var output: Output; + output.position = vec4(color, color, color, color); + return output; +} +)", {}, wgslEnv); } SECTION("vertex index") @@ -256,7 +295,7 @@ module; struct Input { - [builtin(vertex_index)] vert_index: i32 + [builtin(vertex_index)] vert_index: u32 } struct Output @@ -281,7 +320,7 @@ fn main(input: Input) -> Output ExpectGLSL(*shaderModule, R"( struct Input { - int vert_index; + uint vert_index; }; struct Output @@ -292,7 +331,7 @@ struct Output void main() { Input input_; - input_.vert_index = gl_VertexID; + input_.vert_index = uint(gl_VertexID); float color = float(input_.vert_index); Output output_; @@ -306,7 +345,7 @@ void main() ExpectNZSL(*shaderModule, R"( struct Input { - [builtin(vertex_index)] vert_index: i32 + [builtin(vertex_index)] vert_index: u32 } struct Output @@ -325,6 +364,27 @@ fn main(input: Input) -> Output )"); ExpectSPIRV(*shaderModule, R"(OpDecorate %5 Decoration(BuiltIn) BuiltIn(VertexIndex))", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +struct Input +{ + @builtin(vertex_index) vert_index: u32 +} + +struct Output +{ + @builtin(position) position: vec4 +} + +@vertex +fn main(input: Input) -> Output +{ + var color: f32 = f32(input.vert_index); + var output: Output; + output.position = vec4(color, color, color, color); + return output; +} +)"); } SECTION("vertex position") @@ -428,5 +488,15 @@ fn main() -> Output )"); ExpectSPIRV(*shaderModule, R"(OpDecorate %6 Decoration(BuiltIn) BuiltIn(Position))", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@vertex +fn main() -> Output +{ + var output: Output; + output.position = vec4(0.0, 0.5, 1.0, 1.0); + return output; +} +)"); } } diff --git a/tests/src/Tests/CastTests.cpp b/tests/src/Tests/CastTests.cpp index 120deade..51c35c1e 100644 --- a/tests/src/Tests/CastTests.cpp +++ b/tests/src/Tests/CastTests.cpp @@ -2,7 +2,6 @@ #include #include #include -#include TEST_CASE("Casts", "[Shader]") { @@ -136,5 +135,28 @@ fn main() OpStore %28 %47 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var fVal: f32 = 42.0; + var x: f64 = f64(fVal); + var x_2: i32 = i32(fVal); + var x_3: u32 = u32(fVal); + var iVal: i32 = 42; + var x_4: f32 = f32(iVal); + var x_5: f64 = f64(iVal); + var x_6: u32 = u32(iVal); + var uVal: u32 = 42u; + var x_7: f32 = f32(uVal); + var x_8: f64 = f64(uVal); + var x_9: i32 = i32(uVal); + var fToIVal: f32 = f32(42); +} +)", {}, wgslEnv); } } diff --git a/tests/src/Tests/ComparisonTests.cpp b/tests/src/Tests/ComparisonTests.cpp index 92248f00..c81d2f90 100644 --- a/tests/src/Tests/ComparisonTests.cpp +++ b/tests/src/Tests/ComparisonTests.cpp @@ -338,6 +338,49 @@ fn main() OpStore %64 %148 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: i32 = 5; + var y: i32 = 2; + var r: bool = x == y; + var r_2: bool = x != y; + var r_3: bool = x < y; + var r_4: bool = x <= y; + var r_5: bool = x > y; + var r_6: bool = x >= y; + var x_2: f32 = 5.0; + var y_2: f32 = 2.0; + var r_7: bool = x_2 == y_2; + var r_8: bool = x_2 != y_2; + var r_9: bool = x_2 < y_2; + var r_10: bool = x_2 <= y_2; + var r_11: bool = x_2 > y_2; + var r_12: bool = x_2 >= y_2; + var x_3: vec2 = vec2(5, 7); + var y_3: vec2 = vec2(2, 3); + var r_13: vec2 = x_3 == y_3; + var r_14: vec2 = x_3 != y_3; + var r_15: vec2 = x_3 < y_3; + var r_16: vec2 = x_3 <= y_3; + var r_17: vec2 = x_3 > y_3; + var r_18: vec2 = x_3 >= y_3; + var x_4: vec2 = vec2(5.0, 7.0); + var y_4: vec2 = vec2(2.0, 3.0); + var r_19: vec2 = x_4 == y_4; + var r_20: vec2 = x_4 != y_4; + var r_21: vec2 = x_4 < y_4; + var r_22: vec2 = x_4 <= y_4; + var r_23: vec2 = x_4 > y_4; + var r_24: vec2 = x_4 >= y_4; + var x_5: vec3 = vec3(true, false, true); + var y_5: vec3 = vec3(false, false, true); + var r_25: vec3 = x_5 == y_5; + var r_26: vec3 = x_5 != y_5; +} +)"); } SECTION("Unary operators combined with binary operators") @@ -440,5 +483,26 @@ fn main() OpStore %17 %25 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +fn foo() -> bool +{ + return false; +} + +fn bar() -> bool +{ + return true; +} + +@fragment +fn main() +{ + var x: bool = false; + var y: bool = true; + var z: bool = (!x) || y; + var z_2: bool = (!foo()) || (bar()); +} +)"); } } diff --git a/tests/src/Tests/ComputeTests.cpp b/tests/src/Tests/ComputeTests.cpp index cab40a75..fdd39017 100644 --- a/tests/src/Tests/ComputeTests.cpp +++ b/tests/src/Tests/ComputeTests.cpp @@ -226,5 +226,36 @@ fn main(input: Input) OpImageWrite %56 %60 %61 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +// std140 layout +struct Data +{ + tex_size: vec2, + _padding0: f32, + _padding1: f32 +} + +@group(0) @binding(0) var input_tex: texture_storage_2d; +@group(0) @binding(1) var output_tex: texture_storage_2d; +@group(0) @binding(2) var data: Data; + +struct Input +{ + @builtin(global_invocation_id) indices: vec3 +} + +@compute @workgroup_size(32, 32, 1) +fn main(input: Input) +{ + if ((input.indices.x >= data.tex_size.x) || (input.indices.y >= data.tex_size.y)) + { + return; + } + + var value: vec4 = textureLoad(input_tex, vec2(input.indices.xy)); + textureStore(output_tex, vec2(input.indices.xy), value); +} +)"); } } diff --git a/tests/src/Tests/ConstantTests.cpp b/tests/src/Tests/ConstantTests.cpp index 27da405c..70448c47 100644 --- a/tests/src/Tests/ConstantTests.cpp +++ b/tests/src/Tests/ConstantTests.cpp @@ -209,5 +209,31 @@ fn main() OpStore %44 %26 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var foo: f32 = 1.192092896e-07; + var foo_2: f32 = 3.402823466e+38; + var foo_3: f32 = -3.402823466e+38; + var foo_4: f32 = 1.175494351e-38; + var foo_5: f32 = _nzslInfinityf32(); + var foo_6: f32 = _nzslNaNf32(); + var foo_7: f64 = 2.2204460492503131e-016lf; + var foo_8: f64 = 1.7976931348623158e+308lf; + var foo_9: f64 = -1.7976931348623158e+308lf; + var foo_10: f64 = 2.2250738585072014e-308lf; + var foo_11: f64 = _nzslInfinityf64(); + var foo_12: f64 = _nzslNaNf64(); + var foo_13: i32 = 2147483647; + var foo_14: i32 = -2147483648; + var foo_15: u32 = 4294967295u; + var foo_16: u32 = 0u; +} +)", {}, wgslEnv); } } diff --git a/tests/src/Tests/EntryFunctionTests.cpp b/tests/src/Tests/EntryFunctionTests.cpp index 77a66928..0d733370 100644 --- a/tests/src/Tests/EntryFunctionTests.cpp +++ b/tests/src/Tests/EntryFunctionTests.cpp @@ -109,6 +109,24 @@ fn main() -> FragOut OpStore %5 %17 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +struct FragOut +{ + @builtin(frag_depth) depth: f32 +} + +@fragment @early_depth_test(greater_equal) +fn main() -> FragOut +{ + var output: FragOut; + output.depth = 1.0; + return output; +} +)", {}, wgslEnv); } WHEN("Using depth_write(less)") @@ -212,6 +230,24 @@ fn main() -> FragOut OpStore %5 %17 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +struct FragOut +{ + @builtin(frag_depth) depth: f32 +} + +@fragment @early_depth_test(less_equal) +fn main() -> FragOut +{ + var output: FragOut; + output.depth = 0.0; + return output; +} +)", {}, wgslEnv); } WHEN("Using depth_write(replace)") @@ -319,6 +355,21 @@ fn main() -> FragOut OpStore %5 %17 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +struct FragOut +{ + @builtin(frag_depth) depth: f32 +} + +@fragment +fn main() -> FragOut +{ + var output: FragOut; + output.depth = 0.5; + return output; +} +)"); } WHEN("Using depth_write(unchanged)") @@ -447,6 +498,29 @@ fn main(input: FragIn) -> FragOut OpStore %13 %28 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +struct FragIn +{ + @builtin(position) fragCoord: vec4 +} + +struct FragOut +{ + @builtin(frag_depth) depth: f32 +} + +@fragment @early_depth_test(unchanged) +fn main(input: FragIn) -> FragOut +{ + var output: FragOut; + output.depth = input.fragCoord.z; + return output; +} +)", {}, wgslEnv); } } @@ -590,7 +664,19 @@ fn main() %4 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment @early_depth_test(force) +fn main() +{ + +} +)", {}, wgslEnv); } + WHEN("Disabling early fragment tests") { @@ -645,6 +731,14 @@ fn main() %4 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + +} +)"); } } } diff --git a/tests/src/Tests/ExternalTests.cpp b/tests/src/Tests/ExternalTests.cpp index 1ea89ea0..00c41e95 100644 --- a/tests/src/Tests/ExternalTests.cpp +++ b/tests/src/Tests/ExternalTests.cpp @@ -85,8 +85,20 @@ fn main() OpStore %14 %17 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +// Tag: Color map +@group(0) @binding(0) var tex: texture_2d; +@group(0) @binding(1) var texSampler: sampler; + +@fragment +fn main() +{ + var value: vec4 = textureSample(tex, texSampler, vec2(0.0, 0.0)); +} +)"); } - + SECTION("Arrays of texture") { std::string_view nzslSource = R"( @@ -170,6 +182,16 @@ fn main() OpStore %19 %24 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@group(0) @binding(0) var tex: binding_array, 5>; +@group(0) @binding(1) var texSampler: sampler; + +@fragment +fn main() +{ + var value: vec4 = textureSample(tex[2], texSampler, vec3(0.0, 0.0, 0.0)); +})"); } SECTION("Uniform buffers") @@ -285,6 +307,35 @@ fn main() OpStore %23 %30 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +struct f32_stride16 +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// Tag: DataStruct + +// std140 layout +struct Data +{ + // Tag: Values + values: array, + matrices: array, 3>, + _padding0: f32 +} + +@group(0) @binding(0) var data: Data; + +@fragment +fn main() +{ + var value: mat4x4 = data.values[42].value * data.matrices[1]; +} +)"); } SECTION("Storage buffers") @@ -441,6 +492,31 @@ fn main() %24 = OpLabel OpReturn OpFunctionEnd)", {}, spirvEnv, true); + + ExpectWGSL(*shaderModule, R"( +struct Data +{ + values: array +} + +@group(0) @binding(0) var inData: Data; +@group(0) @binding(1) var outData: Data; + +@fragment +fn main() +{ + { + var i: i32 = 0; + var _nzsl_to: i32 = 47; + while (i < _nzsl_to) + { + outData.values[i] = inData.values[i]; + i += 1; + } + + } + +})"); } WHEN("Generating SPIR-V 1.3") @@ -678,6 +754,23 @@ fn main() OpReturn OpFunctionEnd)", {}, spirvEnv, true); } + + ExpectWGSL(*shaderModule, R"( +struct Data +{ + data: u32, + values: array +} + +@group(0) @binding(0) var data: Data; + +@fragment +fn main() +{ + var value: f32 = data.values[42]; + var size: u32 = arrayLength(&data.values); +} +)"); } } @@ -785,6 +878,12 @@ fn main() nzsl::SpirvWriter spirvWriter; CHECK_THROWS_WITH(spirvWriter.Generate(*shaderModule), "unsupported type used in external block (SPIR-V doesn't allow primitive types as uniforms)"); } + + WHEN("Generating WGSL (which doesn't support primitive externals)") + { + nzsl::WgslWriter wgslWriter; + CHECK_THROWS_WITH(wgslWriter.Generate(*shaderModule), "primitive externals have no way to be translated in WGSL"); + } } @@ -1034,6 +1133,28 @@ fn main() %8 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +// std140 layout +struct Data +{ + index: i32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +var data: Data; + +@fragment +fn main() +{ + +} +)", {}, wgslEnv); } @@ -1284,6 +1405,71 @@ fn main() %48 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +struct f32_stride16 +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// std140 layout +struct DirectionalLight +{ + color: vec3, + direction: vec3, + invShadowMapSize: vec2, + ambientFactor: f32, + diffuseFactor: f32, + cascadeCount: u32, + _padding0: f32, + _padding1: f32, + _padding2: f32, + cascadeDistances: array, + viewProjMatrices: array, 4>, + _padding3: f32, + _padding4: f32 +} + +// std140 layout +struct LightData +{ + directionalLights: array, + directionalLightCount: u32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +@group(0) @binding(0) var lightData: LightData; + +@fragment +fn main() +{ + { + var lightIndex: u32 = 0u; + var _nzsl_to: u32 = lightData.directionalLightCount; + while (lightIndex < _nzsl_to) + { + var light: DirectionalLight; + light.color = lightData.directionalLights[lightIndex].color; + light.direction = lightData.directionalLights[lightIndex].direction; + light.invShadowMapSize = lightData.directionalLights[lightIndex].invShadowMapSize; + light.ambientFactor = lightData.directionalLights[lightIndex].ambientFactor; + light.diffuseFactor = lightData.directionalLights[lightIndex].diffuseFactor; + light.cascadeCount = lightData.directionalLights[lightIndex].cascadeCount; + light.cascadeDistances = lightData.directionalLights[lightIndex].cascadeDistances; + light.cascadeDistances = lightData.directionalLights[lightIndex].cascadeDistances; + var lightCopy: DirectionalLight = light; + lightIndex += 1u; + } + + } + +} +)"); } SECTION("named external") @@ -1380,6 +1566,25 @@ fn main() OpStore %19 %26 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +// std140 layout +struct Data +{ + color: vec4 +} + +// Tag: Color map +@group(0) @binding(0) var Instance_tex: texture_2d; +@group(0) @binding(1) var Instance_texSampler: sampler; +@group(0) @binding(2) var Instance_data: Data; + +@fragment +fn main() +{ + var value: vec4 = (textureSample(Instance_tex, Instance_texSampler, vec2(0.0, 0.0))) * Instance_data.color; +} +)"); } SECTION("named external shadowing") @@ -1471,6 +1676,19 @@ fn main() OpStore %13 %14 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +// Tag: Color map +@group(0) @binding(0) var Viewer_tex: texture_2d; +@group(0) @binding(1) var Viewer_texSampler: sampler; + +@fragment +fn main() +{ + var Viewer_tex: f32 = 0.0; + var value: f32 = Viewer_tex; +} +)"); } SECTION("Proper SPIR-V 1.4 generation") @@ -1935,5 +2153,76 @@ fn main(input: VertIn) -> VertOut OpStore %40 %87 OpReturn OpFunctionEnd)", {}, env, true); + + ExpectWGSL(*shaderModule, R"( +// std140 layout +struct MaterialData +{ + color: vec4 +} + +// std140 layout +struct InstanceData +{ + worldViewProjMat: mat4x4, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +@group(0) @binding(0) var tex: texture_2d; +@group(0) @binding(1) var texSampler: sampler; +@group(0) @binding(2) var instanceData: InstanceData; +@group(0) @binding(3) var materialData: MaterialData; + +struct VertIn +{ + @location(0) pos: vec3, + @location(1) uv: vec2 +} + +struct VertOut +{ + @builtin(position) pos: vec4, + @location(0) uv: vec2 +} + +struct FragOut +{ + @location(0) color: vec4 +} + +fn GetBaseColor() -> vec4 +{ + return materialData.color; +} + +fn GetWorldMatrix() -> mat4x4 +{ + return Intermediate(); +} + +fn Intermediate() -> mat4x4 +{ + return instanceData.worldViewProjMat; +} + +@fragment +fn main(input: VertOut) -> FragOut +{ + var output: FragOut; + output.color = (GetBaseColor()) * (textureSample(tex, texSampler, input.uv)); + return output; +} + +@vertex +fn main_2(input: VertIn) -> VertOut +{ + var output: VertOut; + output.pos = (GetWorldMatrix()) * (vec4(input.pos, 1.0)); + output.uv = input.uv; + return output; +} +)"); } } diff --git a/tests/src/Tests/FilesystemResolverTests.cpp b/tests/src/Tests/FilesystemResolverTests.cpp index 711b7337..89348957 100644 --- a/tests/src/Tests/FilesystemResolverTests.cpp +++ b/tests/src/Tests/FilesystemResolverTests.cpp @@ -219,4 +219,66 @@ fn main() -> Output OpStore %25 %57 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +// This file was generated by NZSL compiler (Nazara Shading Language) +// Author "SirLynix" +// Description: "Test module" +// License: "MIT" + +// Author "SirLynix" +// Description: "Test color module" +// License: "MIT" + +// Module Color + +@group(0) @binding(0) var _Color_tex1: texture_2d; +@group(0) @binding(1) var _Color_tex1Sampler: sampler; + +fn _Color_GenerateColor() -> vec4 +{ + return textureSample(_Color_tex1, _Color_tex1Sampler, vec2(0.0, 0.0)); +} + +fn _Color_GetColor() -> vec4 +{ + return _Color_GenerateColor(); +} +// Module DataStruct +struct _DataStruct_Data +{ + color: vec4 +} +// Module OutputStruct + +fn _OutputStruct_GetColorFromData(data: _DataStruct_Data) -> vec4 +{ + return data.color * vec4(0.5, 0.5, 0.5, 1.0); +} + +struct _OutputStruct_Output +{ + @location(0) color: vec4 +} + +// std140 layout +struct PushConstants +{ + color: vec4 +} + +var ExternalResources_constants: PushConstants; + +@fragment +fn main() -> _OutputStruct_Output +{ + var data: _DataStruct_Data; + data.color = _Color_GetColor(); + var output: _OutputStruct_Output; + output.color = (_OutputStruct_GetColorFromData(data)) * ExternalResources_constants.color; + return output; +})", {}, wgslEnv); } diff --git a/tests/src/Tests/FunctionsTests.cpp b/tests/src/Tests/FunctionsTests.cpp index 22d0ab1a..cc7a6735 100644 --- a/tests/src/Tests/FunctionsTests.cpp +++ b/tests/src/Tests/FunctionsTests.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include TEST_CASE("functions", "[Shader]") @@ -14,12 +15,12 @@ module; struct FragOut { - [location(0)] value: f32 + [location(0)] value: vec4[f32] } -fn GetValue() -> f32 +fn GetValue() -> vec4[f32] { - return 42.0; + return vec4[f32](42.0, 42.0, 42.0, 1.0); } [entry(frag)] @@ -36,13 +37,13 @@ fn main() -> FragOut ResolveModule(*shaderModule); ExpectGLSL(*shaderModule, R"( -float GetValue() +vec4 GetValue() { - return 42.0; + return vec4(42.0, 42.0, 42.0, 1.0); } /*************** Outputs ***************/ -layout(location = 0) out float _nzslOutvalue; +layout(location = 0) out vec4 _nzslOutvalue; void main() { @@ -55,9 +56,9 @@ void main() )"); ExpectNZSL(*shaderModule, R"( -fn GetValue() -> f32 +fn GetValue() -> vec4[f32] { - return 42.0; + return vec4[f32](42.0, 42.0, 42.0, 1.0); } [entry(frag)] @@ -72,10 +73,6 @@ fn main() -> FragOut ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel -OpReturnValue -OpFunctionEnd -OpFunction -OpLabel OpVariable OpFunctionCall OpFNegate @@ -86,6 +83,25 @@ OpCompositeExtract OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +struct FragOut +{ + @location(0) value: vec4 +} + +fn GetValue() -> vec4 +{ + return vec4(42.0, 42.0, 42.0, 1.0); +} + +@fragment +fn main() -> FragOut +{ + var output: FragOut; + output.value = -GetValue(); + return output; +})"); } SECTION("Unordered functions") @@ -96,12 +112,12 @@ module; struct FragOut { - [location(0)] value: f32 + [location(0)] value: vec4[f32] } -fn bar() -> f32 +fn bar() -> vec4[f32] { - return 42.0; + return vec4[f32](42.0, 42.0, 42.0, 1.0); } [entry(frag)] @@ -113,12 +129,12 @@ fn main() -> FragOut return output; } -fn baz() -> f32 +fn baz() -> vec4[f32] { return foo(); } -fn foo() -> f32 +fn foo() -> vec4[f32] { return bar(); } @@ -128,15 +144,15 @@ fn foo() -> f32 ResolveModule(*shaderModule); ExpectGLSL(*shaderModule, R"( -float bar() +vec4 bar() { - return 42.0; + return vec4(42.0, 42.0, 42.0, 1.0); } -float baz(); +vec4 baz(); /*************** Outputs ***************/ -layout(location = 0) out float _nzslOutvalue; +layout(location = 0) out vec4 _nzslOutvalue; void main() { @@ -147,23 +163,23 @@ void main() return; } -float foo(); +vec4 foo(); -float baz() +vec4 baz() { return foo(); } -float foo() +vec4 foo() { return bar(); } )"); ExpectNZSL(*shaderModule, R"( -fn bar() -> f32 +fn bar() -> vec4[f32] { - return 42.0; + return vec4[f32](42.0, 42.0, 42.0, 1.0); } [entry(frag)] @@ -174,12 +190,12 @@ fn main() -> FragOut return output; } -fn baz() -> f32 +fn baz() -> vec4[f32] { return foo(); } -fn foo() -> f32 +fn foo() -> vec4[f32] { return bar(); } @@ -188,10 +204,6 @@ fn foo() -> f32 ExpectSPIRV(*shaderModule, R"( OpFunction OpLabel -OpReturnValue -OpFunctionEnd -OpFunction -OpLabel OpVariable OpFunctionCall OpAccessChain @@ -211,6 +223,31 @@ OpLabel OpFunctionCall OpReturnValue OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +fn bar() -> vec4 +{ + return vec4(42.0, 42.0, 42.0, 1.0); +} + +@fragment +fn main() -> FragOut +{ + var output: FragOut; + output.value = baz(); + return output; +} + +fn baz() -> vec4 +{ + return foo(); +} + +fn foo() -> vec4 +{ + return bar(); +} +)"); } SECTION("inout function call") @@ -362,6 +399,26 @@ fn main() -> FragOut OpStore %12 %50 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +fn Half(color: ptr>, value: ptr, inValue: f32, inValue2: f32) +{ + *color *= 2.0; + *value = 10.0; +} + +@fragment +fn main() -> FragOut +{ + var output: FragOut; + var mainColor: vec3 = vec3(1.0, 1.0, 1.0); + var inValue: f32 = 2.0; + var inValue2: f32 = 1.0; + Half(&mainColor, &output.value2, inValue, inValue2); + output.value = mainColor.x; + return output; +} +)"); } SECTION("passing sampler to function") @@ -489,6 +546,29 @@ fn main() -> FragOut OpStore %14 %33 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +fn sample_center(tex: texture_2d, texSampler: sampler) -> vec4 +{ + return textureSample(tex, texSampler, vec2(0.5, 0.5)); +} + +@group(0) @binding(0) var ExtData_texture: texture_2d; +@group(0) @binding(1) var ExtData_textureSampler: sampler; + +struct FragOut +{ + @location(0) value: vec4 +} + +@fragment +fn main() -> FragOut +{ + var output: FragOut; + output.value = sample_center(ExtData_texture, ExtData_textureSampler); + return output; +} +)"); } SECTION("passing sampler array to function") @@ -622,5 +702,8 @@ fn main() -> FragOut OpStore %19 %39 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter wgslWriter; + CHECK_THROWS_WITH(wgslWriter.Generate(*shaderModule), "WGSL does not support sampled texture array as funtion parameter"); } } diff --git a/tests/src/Tests/IdentifierTests.cpp b/tests/src/Tests/IdentifierTests.cpp index cdf80ff0..46e7903f 100644 --- a/tests/src/Tests/IdentifierTests.cpp +++ b/tests/src/Tests/IdentifierTests.cpp @@ -25,10 +25,10 @@ fn int() -> i32 struct output { - active: vec3[f32], - active_: vec2[i32], - _nzsl: i32, - _: f32 + [location(0)] active: vec3[f32], + [location(1)] active_: vec2[i32], + [location(2)] _nzsl: i32, + [location(3)] _: f32 } [entry(frag)] @@ -66,10 +66,10 @@ struct output_ }; /*************** Outputs ***************/ -out vec3 _nzslOutactive_; -out ivec2 _nzslOutactive2_2; -out int _nzslOut_; -out float _nzslOut_2_2; +layout(location = 0) out vec3 _nzslOutactive_; +layout(location = 1) out ivec2 _nzslOutactive2_2; +layout(location = 2) out int _nzslOut_; +layout(location = 3) out float _nzslOut_2_2; void main() { @@ -104,10 +104,10 @@ fn int() -> i32 struct output { - active: vec3[f32], - active_: vec2[i32], - _nzsl: i32, - _: f32 + [location(0)] active: vec3[f32], + [location(1)] active_: vec2[i32], + [location(2)] _nzsl: i32, + [location(3)] _: f32 } [entry(frag)] @@ -145,7 +145,45 @@ OpCompositeConstruct OpAccessChain OpStore OpLoad +OpCompositeExtract +OpStore +OpCompositeExtract +OpStore +OpCompositeExtract +OpStore +OpCompositeExtract +OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@group(0) @binding(0) var texture: texture_2d; +@group(0) @binding(1) var textureSampler: sampler; + +fn int() -> i32 +{ + return 42; +} + +struct output +{ + @location(0) active_: vec3, + @location(1) active2_2: vec2, + @location(2) _2_2: i32, + @location(3) _2_2_2: f32 +} + +@fragment +fn main() -> output +{ + var input: i32 = int(); + var input_: i32 = 0; + var fl2_oa8_t: f32 = 42.0; + var outValue: output; + var _nzsl_cachedResult: f32 = (f32(input)) + fl2_oa8_t; + outValue.active_ = vec3(_nzsl_cachedResult, _nzsl_cachedResult, _nzsl_cachedResult); + return outValue; +} +)"); } } diff --git a/tests/src/Tests/ImplicitTests.cpp b/tests/src/Tests/ImplicitTests.cpp index a1ef4c8c..0c9ae996 100644 --- a/tests/src/Tests/ImplicitTests.cpp +++ b/tests/src/Tests/ImplicitTests.cpp @@ -60,6 +60,15 @@ fn foo() OpStore %10 %14 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn foo() +{ + var x: f32; + var v: vec3 = vec3(x, x, x); +} +)"); } SECTION("Implicit arrays") @@ -213,6 +222,39 @@ fn foo() OpStore %61 %64 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +const vertPos: array, 3> = array, 3>( + vec2(-1.0, 1.0), + vec2(-1.0, -3.0), + vec2(3.0, 1.0) +); + +const a: array, 3> = array, 3>( + vec3(1, 2, 3), + vec3(4, 5, 6), + vec3(7, 8, 9) +); + +const b: array, 3> = array, 3>( + vec2(1.0, 2.0), + vec2(3.0, 4.0), + vec2(5.0, 6.0) +); + +const c: array = array( + true, + false, + false +); + +@fragment +fn foo() +{ + var value: vec3 = vec3(-1, -3, 42); + var runtimeArray: array, 3> = array, 3>(value, value, vec3(1, 2, 3)); +} +)"); } SECTION("Implicit matrices") @@ -388,5 +430,37 @@ fn foo() OpStore %52 %96 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment +fn foo() +{ + var x: f32 = 1.0; + var v: vec3 = vec3(-2.0, -1.0, 0.0); + var _nzsl_matrix: mat4x4; + _nzsl_matrix[0u] = vec4(x, 0.0, 0.0, 0.0); + _nzsl_matrix[1u] = vec4(0.0, x, 0.0, 0.0); + _nzsl_matrix[2u] = vec4(0.0, 0.0, x, 0.0); + _nzsl_matrix[3u] = vec4(0.0, 0.0, 0.0, x); + var m1: mat4x4 = _nzsl_matrix; + var _nzsl_matrix_2: mat3x3; + _nzsl_matrix_2[0u] = m1[0u].xyz; + _nzsl_matrix_2[1u] = m1[1u].xyz; + _nzsl_matrix_2[2u] = m1[2u].xyz; + var m2: mat3x3 = _nzsl_matrix_2; + var _nzsl_matrix_3: mat2x2; + _nzsl_matrix_3[0u] = vec2(x, x); + _nzsl_matrix_3[1u] = vec2(x, x); + var m3: mat2x2 = _nzsl_matrix_3; + var _nzsl_matrix_4: mat3x3; + _nzsl_matrix_4[0u] = v; + _nzsl_matrix_4[1u] = vec3(1.0, 2.0, 3.0); + _nzsl_matrix_4[2u] = vec3(4.0, 5.0, 6.0); + var m4: mat3x3 = _nzsl_matrix_4; +} +)", {}, wgslEnv); } } diff --git a/tests/src/Tests/InputOutputTests.cpp b/tests/src/Tests/InputOutputTests.cpp index c0f0f9f6..35031a56 100644 --- a/tests/src/Tests/InputOutputTests.cpp +++ b/tests/src/Tests/InputOutputTests.cpp @@ -25,16 +25,17 @@ external struct VertIn { - [builtin(instance_index)] instance_index: i32, - [builtin(draw_index)] draw_index: i32, - [builtin(vertex_index)] vertex_index: i32 + [builtin(instance_index)] instance_index: u32, + [builtin(draw_index)] draw_index: u32, + [builtin(vertex_index)] vertex_index: u32 } struct VertOut { - [location(0), interp(flat)] instance_index: i32, + [location(0), interp(flat)] instance_index: u32, [location(1), interp(no_perspective)] x: f32, [location(2), interp(smooth)] y: f32, + [builtin(position)] position: vec4[f32], } struct FragOut @@ -58,6 +59,7 @@ fn main(input: VertIn) -> VertOut output.instance_index = input.instance_index; output.x = f32(input.draw_index); output.y = f32(input.vertex_index); + output.position = vec4[f32](0.0, 0.0, 0.0, 1.0); return output; } @@ -79,16 +81,17 @@ layout(std430) buffer _nzslBindingdata struct VertIn { - int instance_index; - int draw_index; - int vertex_index; + uint instance_index; + uint draw_index; + uint vertex_index; }; struct VertOut { - int instance_index; + uint instance_index; float x; float y; + vec4 position; }; struct FragOut @@ -97,25 +100,27 @@ struct FragOut }; /*************** Outputs ***************/ -layout(location = 0) flat out int _nzslOutinstance_index; +layout(location = 0) flat out uint _nzslOutinstance_index; layout(location = 1) noperspective out float _nzslOutx; layout(location = 2) smooth out float _nzslOuty; void main() { VertIn input_; - input_.instance_index = (gl_BaseInstance + gl_InstanceID); - input_.draw_index = gl_DrawID; - input_.vertex_index = gl_VertexID; + input_.instance_index = uint(gl_BaseInstance) + uint(gl_InstanceID); + input_.draw_index = uint(gl_DrawID); + input_.vertex_index = uint(gl_VertexID); VertOut output_; output_.instance_index = input_.instance_index; output_.x = float(input_.draw_index); output_.y = float(input_.vertex_index); + output_.position = vec4(0.0, 0.0, 0.0, 1.0); _nzslOutinstance_index = output_.instance_index; _nzslOutx = output_.x; _nzslOuty = output_.y; + gl_Position = output_.position; return; } )", {}, glslEnv); @@ -128,16 +133,17 @@ layout(std430) buffer _nzslBindingdata struct VertIn { - int instance_index; - int draw_index; - int vertex_index; + uint instance_index; + uint draw_index; + uint vertex_index; }; struct VertOut { - int instance_index; + uint instance_index; float x; float y; + vec4 position; }; struct FragOut @@ -146,7 +152,7 @@ struct FragOut }; /**************** Inputs ****************/ -layout(location = 0) flat in int _nzslIninstance_index; +layout(location = 0) flat in uint _nzslIninstance_index; layout(location = 1) noperspective in float _nzslInx; layout(location = 2) smooth in float _nzslIny; @@ -185,16 +191,17 @@ external struct VertIn { - [builtin(instance_index)] instance_index: i32, - [builtin(draw_index)] draw_index: i32, - [builtin(vertex_index)] vertex_index: i32 + [builtin(instance_index)] instance_index: u32, + [builtin(draw_index)] draw_index: u32, + [builtin(vertex_index)] vertex_index: u32 } struct VertOut { - [location(0), interp(flat)] instance_index: i32, + [location(0), interp(flat)] instance_index: u32, [location(1), interp(no_perspective)] x: f32, - [location(2), interp(smooth)] y: f32 + [location(2), interp(smooth)] y: f32, + [builtin(position)] position: vec4[f32] } struct FragOut @@ -217,6 +224,7 @@ fn main(input: VertIn) -> VertOut output.instance_index = input.instance_index; output.x = f32(input.draw_index); output.y = f32(input.vertex_index); + output.position = vec4[f32](0.0, 0.0, 0.0, 1.0); return output; } )"); @@ -226,65 +234,68 @@ fn main(input: VertIn) -> VertOut spirvEnv.spvMinorVersion = 3; ExpectSPIRV(*shaderModule, R"( - OpCapability Capability(Shader) OpCapability Capability(DrawParameters) OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) - OpEntryPoint ExecutionModel(Fragment) %37 "main" %11 %15 %18 %23 - OpEntryPoint ExecutionModel(Vertex) %38 "main" %27 %28 %29 %33 %35 %36 - OpExecutionMode %37 ExecutionMode(OriginUpperLeft) + OpEntryPoint ExecutionModel(Fragment) %42 "main" %11 %16 %19 %24 + OpEntryPoint ExecutionModel(Vertex) %43 "main" %28 %29 %30 %34 %36 %37 %38 + OpExecutionMode %42 ExecutionMode(OriginUpperLeft) OpSource SourceLanguage(NZSL) 4198400 OpSourceExtension "Version: 1.1" OpName %4 "ColorData" OpMemberName %4 0 "colors" - OpName %20 "VertOut" - OpMemberName %20 0 "instance_index" - OpMemberName %20 1 "x" - OpMemberName %20 2 "y" - OpName %24 "FragOut" - OpMemberName %24 0 "color" - OpName %30 "VertIn" - OpMemberName %30 0 "instance_index" - OpMemberName %30 1 "draw_index" - OpMemberName %30 2 "vertex_index" + OpName %21 "VertOut" + OpMemberName %21 0 "instance_index" + OpMemberName %21 1 "x" + OpMemberName %21 2 "y" + OpMemberName %21 3 "position" + OpName %25 "FragOut" + OpMemberName %25 0 "color" + OpName %31 "VertIn" + OpMemberName %31 0 "instance_index" + OpMemberName %31 1 "draw_index" + OpMemberName %31 2 "vertex_index" OpName %6 "data" OpName %11 "instance_index" - OpName %15 "x" - OpName %18 "y" - OpName %23 "color" - OpName %27 "instance_index" - OpName %28 "draw_index" - OpName %29 "vertex_index" - OpName %33 "instance_index" - OpName %35 "x" - OpName %36 "y" - OpName %37 "main" - OpName %38 "main" + OpName %16 "x" + OpName %19 "y" + OpName %24 "color" + OpName %28 "instance_index" + OpName %29 "draw_index" + OpName %30 "vertex_index" + OpName %34 "instance_index" + OpName %36 "x" + OpName %37 "y" + OpName %38 "position" + OpName %42 "main" + OpName %43 "main" OpDecorate %6 Decoration(Binding) 0 OpDecorate %6 Decoration(DescriptorSet) 0 - OpDecorate %27 Decoration(BuiltIn) BuiltIn(InstanceIndex) - OpDecorate %28 Decoration(BuiltIn) BuiltIn(DrawIndex) - OpDecorate %29 Decoration(BuiltIn) BuiltIn(VertexIndex) + OpDecorate %28 Decoration(BuiltIn) BuiltIn(InstanceIndex) + OpDecorate %29 Decoration(BuiltIn) BuiltIn(DrawIndex) + OpDecorate %30 Decoration(BuiltIn) BuiltIn(VertexIndex) + OpDecorate %38 Decoration(BuiltIn) BuiltIn(Position) OpDecorate %11 Decoration(Location) 0 - OpDecorate %15 Decoration(Location) 1 - OpDecorate %18 Decoration(Location) 2 - OpDecorate %23 Decoration(Location) 0 - OpDecorate %33 Decoration(Location) 0 - OpDecorate %35 Decoration(Location) 1 - OpDecorate %36 Decoration(Location) 2 + OpDecorate %16 Decoration(Location) 1 + OpDecorate %19 Decoration(Location) 2 + OpDecorate %24 Decoration(Location) 0 + OpDecorate %34 Decoration(Location) 0 + OpDecorate %36 Decoration(Location) 1 + OpDecorate %37 Decoration(Location) 2 OpDecorate %11 Decoration(Flat) - OpDecorate %15 Decoration(NoPerspective) - OpDecorate %33 Decoration(Flat) - OpDecorate %35 Decoration(NoPerspective) + OpDecorate %16 Decoration(NoPerspective) + OpDecorate %34 Decoration(Flat) + OpDecorate %36 Decoration(NoPerspective) OpDecorate %3 Decoration(ArrayStride) 16 OpDecorate %4 Decoration(Block) OpMemberDecorate %4 0 Decoration(Offset) 0 - OpMemberDecorate %20 0 Decoration(Offset) 0 - OpMemberDecorate %20 1 Decoration(Offset) 4 - OpMemberDecorate %20 2 Decoration(Offset) 8 - OpMemberDecorate %24 0 Decoration(Offset) 0 - OpMemberDecorate %30 0 Decoration(Offset) 0 - OpMemberDecorate %30 1 Decoration(Offset) 4 - OpMemberDecorate %30 2 Decoration(Offset) 8 + OpMemberDecorate %21 0 Decoration(Offset) 0 + OpMemberDecorate %21 1 Decoration(Offset) 4 + OpMemberDecorate %21 2 Decoration(Offset) 8 + OpMemberDecorate %21 3 Decoration(Offset) 16 + OpMemberDecorate %25 0 Decoration(Offset) 0 + OpMemberDecorate %31 0 Decoration(Offset) 0 + OpMemberDecorate %31 1 Decoration(Offset) 4 + OpMemberDecorate %31 2 Decoration(Offset) 8 %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeRuntimeArray %2 @@ -292,96 +303,163 @@ fn main(input: VertIn) -> VertOut %5 = OpTypePointer StorageClass(StorageBuffer) %4 %7 = OpTypeVoid %8 = OpTypeFunction %7 - %9 = OpTypeInt 32 1 + %9 = OpTypeInt 32 0 %10 = OpTypePointer StorageClass(Input) %9 -%12 = OpConstant %9 i32(0) -%13 = OpTypePointer StorageClass(Function) %9 -%14 = OpTypePointer StorageClass(Input) %1 -%16 = OpConstant %9 i32(1) -%17 = OpTypePointer StorageClass(Function) %1 -%19 = OpConstant %9 i32(2) -%20 = OpTypeStruct %9 %1 %1 -%21 = OpTypePointer StorageClass(Function) %20 -%22 = OpTypePointer StorageClass(Output) %2 -%24 = OpTypeStruct %2 -%25 = OpTypePointer StorageClass(Function) %24 -%26 = OpTypeRuntimeArray %2 -%30 = OpTypeStruct %9 %9 %9 -%31 = OpTypePointer StorageClass(Function) %30 -%32 = OpTypePointer StorageClass(Output) %9 -%34 = OpTypePointer StorageClass(Output) %1 -%47 = OpTypePointer StorageClass(StorageBuffer) %2 -%57 = OpTypePointer StorageClass(Function) %2 +%12 = OpTypeInt 32 1 +%13 = OpConstant %12 i32(0) +%14 = OpTypePointer StorageClass(Function) %9 +%15 = OpTypePointer StorageClass(Input) %1 +%17 = OpConstant %12 i32(1) +%18 = OpTypePointer StorageClass(Function) %1 +%20 = OpConstant %12 i32(2) +%21 = OpTypeStruct %9 %1 %1 %2 +%22 = OpTypePointer StorageClass(Function) %21 +%23 = OpTypePointer StorageClass(Output) %2 +%25 = OpTypeStruct %2 +%26 = OpTypePointer StorageClass(Function) %25 +%27 = OpTypeRuntimeArray %2 +%31 = OpTypeStruct %9 %9 %9 +%32 = OpTypePointer StorageClass(Function) %31 +%33 = OpTypePointer StorageClass(Output) %9 +%35 = OpTypePointer StorageClass(Output) %1 +%39 = OpConstant %12 i32(3) +%40 = OpConstant %1 f32(0) +%41 = OpConstant %1 f32(1) +%52 = OpTypePointer StorageClass(StorageBuffer) %2 +%62 = OpTypePointer StorageClass(Function) %2 %6 = OpVariable %5 StorageClass(StorageBuffer) %11 = OpVariable %10 StorageClass(Input) -%15 = OpVariable %14 StorageClass(Input) -%18 = OpVariable %14 StorageClass(Input) -%23 = OpVariable %22 StorageClass(Output) -%27 = OpVariable %10 StorageClass(Input) +%16 = OpVariable %15 StorageClass(Input) +%19 = OpVariable %15 StorageClass(Input) +%24 = OpVariable %23 StorageClass(Output) %28 = OpVariable %10 StorageClass(Input) %29 = OpVariable %10 StorageClass(Input) -%33 = OpVariable %32 StorageClass(Output) -%35 = OpVariable %34 StorageClass(Output) -%36 = OpVariable %34 StorageClass(Output) -%37 = OpFunction %7 FunctionControl(0) %8 -%39 = OpLabel -%40 = OpVariable %25 StorageClass(Function) -%41 = OpVariable %21 StorageClass(Function) -%42 = OpAccessChain %13 %41 %12 - OpCopyMemory %42 %11 -%43 = OpAccessChain %17 %41 %16 - OpCopyMemory %43 %15 -%44 = OpAccessChain %17 %41 %19 - OpCopyMemory %44 %18 -%45 = OpAccessChain %13 %41 %12 -%46 = OpLoad %9 %45 -%48 = OpAccessChain %47 %6 %12 %46 -%49 = OpLoad %2 %48 -%50 = OpAccessChain %17 %41 %16 -%51 = OpLoad %1 %50 -%52 = OpVectorTimesScalar %2 %49 %51 -%53 = OpAccessChain %17 %41 %19 -%54 = OpLoad %1 %53 -%55 = OpVectorTimesScalar %2 %52 %54 -%56 = OpAccessChain %57 %40 %12 - OpStore %56 %55 -%58 = OpLoad %24 %40 -%59 = OpCompositeExtract %2 %58 0 - OpStore %23 %59 +%30 = OpVariable %10 StorageClass(Input) +%34 = OpVariable %33 StorageClass(Output) +%36 = OpVariable %35 StorageClass(Output) +%37 = OpVariable %35 StorageClass(Output) +%38 = OpVariable %23 StorageClass(Output) +%42 = OpFunction %7 FunctionControl(0) %8 +%44 = OpLabel +%45 = OpVariable %26 StorageClass(Function) +%46 = OpVariable %22 StorageClass(Function) +%47 = OpAccessChain %14 %46 %13 + OpCopyMemory %47 %11 +%48 = OpAccessChain %18 %46 %17 + OpCopyMemory %48 %16 +%49 = OpAccessChain %18 %46 %20 + OpCopyMemory %49 %19 +%50 = OpAccessChain %14 %46 %13 +%51 = OpLoad %9 %50 +%53 = OpAccessChain %52 %6 %13 %51 +%54 = OpLoad %2 %53 +%55 = OpAccessChain %18 %46 %17 +%56 = OpLoad %1 %55 +%57 = OpVectorTimesScalar %2 %54 %56 +%58 = OpAccessChain %18 %46 %20 +%59 = OpLoad %1 %58 +%60 = OpVectorTimesScalar %2 %57 %59 +%61 = OpAccessChain %62 %45 %13 + OpStore %61 %60 +%63 = OpLoad %25 %45 +%64 = OpCompositeExtract %2 %63 0 + OpStore %24 %64 OpReturn OpFunctionEnd -%38 = OpFunction %7 FunctionControl(0) %8 -%60 = OpLabel -%61 = OpVariable %21 StorageClass(Function) -%62 = OpVariable %31 StorageClass(Function) -%63 = OpAccessChain %13 %62 %12 - OpCopyMemory %63 %27 -%64 = OpAccessChain %13 %62 %16 - OpCopyMemory %64 %28 -%65 = OpAccessChain %13 %62 %19 - OpCopyMemory %65 %29 -%66 = OpAccessChain %13 %62 %12 -%67 = OpLoad %9 %66 -%68 = OpAccessChain %13 %61 %12 - OpStore %68 %67 -%69 = OpAccessChain %13 %62 %16 -%70 = OpLoad %9 %69 -%71 = OpConvertSToF %1 %70 -%72 = OpAccessChain %17 %61 %16 - OpStore %72 %71 -%73 = OpAccessChain %13 %62 %19 -%74 = OpLoad %9 %73 -%75 = OpConvertSToF %1 %74 -%76 = OpAccessChain %17 %61 %19 - OpStore %76 %75 -%77 = OpLoad %20 %61 -%78 = OpCompositeExtract %9 %77 0 - OpStore %33 %78 -%79 = OpCompositeExtract %1 %77 1 - OpStore %35 %79 -%80 = OpCompositeExtract %1 %77 2 - OpStore %36 %80 +%43 = OpFunction %7 FunctionControl(0) %8 +%65 = OpLabel +%66 = OpVariable %22 StorageClass(Function) +%67 = OpVariable %32 StorageClass(Function) +%68 = OpAccessChain %14 %67 %13 + OpCopyMemory %68 %28 +%69 = OpAccessChain %14 %67 %17 + OpCopyMemory %69 %29 +%70 = OpAccessChain %14 %67 %20 + OpCopyMemory %70 %30 +%71 = OpAccessChain %14 %67 %13 +%72 = OpLoad %9 %71 +%73 = OpAccessChain %14 %66 %13 + OpStore %73 %72 +%74 = OpAccessChain %14 %67 %17 +%75 = OpLoad %9 %74 +%76 = OpConvertUToF %1 %75 +%77 = OpAccessChain %18 %66 %17 + OpStore %77 %76 +%78 = OpAccessChain %14 %67 %20 +%79 = OpLoad %9 %78 +%80 = OpConvertUToF %1 %79 +%81 = OpAccessChain %18 %66 %20 + OpStore %81 %80 +%82 = OpCompositeConstruct %2 %40 %40 %40 %41 +%83 = OpAccessChain %62 %66 %39 + OpStore %83 %82 +%84 = OpLoad %21 %66 +%85 = OpCompositeExtract %9 %84 0 + OpStore %34 %85 +%86 = OpCompositeExtract %1 %84 1 + OpStore %36 %86 +%87 = OpCompositeExtract %1 %84 2 + OpStore %37 %87 +%88 = OpCompositeExtract %2 %84 3 + OpStore %38 %88 OpReturn OpFunctionEnd)", {}, spirvEnv, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +struct _nzslBuiltinEmulationStruct +{ + draw_index: u32, + +} +@group(0) @binding(0) var _nzslBuiltinEmulation: _nzslBuiltinEmulationStruct; + +struct ColorData +{ + colors: array> +} + +@group(0) @binding(1) var data: ColorData; + +struct VertIn +{ + @builtin(instance_index) instance_index: u32, + @builtin(vertex_index) vertex_index: u32 +} + +struct VertOut +{ + @location(0) @interpolate(flat) instance_index: u32, + @location(1) @interpolate(perspective) x: f32, + @location(2) @interpolate(linear) y: f32, + @builtin(position) position: vec4 +} + +struct FragOut +{ + @location(0) color: vec4 +} + +@fragment +fn main(input: VertOut) -> FragOut +{ + var output: FragOut; + output.color = (data.colors[input.instance_index] * input.x) * input.y; + return output; +} + +@vertex +fn main_2(input: VertIn) -> VertOut +{ + var output: VertOut; + output.instance_index = input.instance_index; + output.x = f32(_nzslBuiltinEmulation.draw_index); + output.y = f32(input.vertex_index); + output.position = vec4(0.0, 0.0, 0.0, 1.0); + return output; +} +)", {}, wgslEnv); } } diff --git a/tests/src/Tests/IntrinsicTests.cpp b/tests/src/Tests/IntrinsicTests.cpp index b2fd5d7f..02b19e86 100644 --- a/tests/src/Tests/IntrinsicTests.cpp +++ b/tests/src/Tests/IntrinsicTests.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include TEST_CASE("intrinsics", "[Shader]") @@ -110,6 +111,28 @@ fn main() OpStore %23 %25 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +// std140 layout +struct DataStruct +{ + _padding0: f32, + _padding1: f32, + _padding2: f32, + _padding3: f32, + values: array +} + +@group(0) @binding(0) var data: DataStruct; + +@fragment +fn main() +{ + var a: array = array(1.0, 2.0, 3.0); + var arraySize: u32 = 3; + var dynArraySize: u32 = arrayLength(&data.values); +} +)"); } WHEN("testing texture intrinsics") @@ -123,13 +146,10 @@ module; external { tex1D: sampler1D[f32], - tex1DArray: sampler1D_array[f32], tex2D: sampler2D[f32], tex2DArray: sampler2D_array[f32], tex3D: sampler3D[f32], texCube: sampler_cube[f32], - tex1DDepth: depth_sampler1D[f32], - tex1DArrayDepth: depth_sampler1D_array[f32], tex2DDepth: depth_sampler2D[f32], tex2DArrayDepth: depth_sampler2D_array[f32], texCubeDepth: depth_sampler_cube[f32], @@ -144,14 +164,11 @@ fn main() let uv3f = vec3[f32](0.0, 1.0, 2.0); let sampleResult1 = tex1D.Sample(uv1f); - let sampleResult2 = tex1DArray.Sample(uv2f); - let sampleResult3 = tex2D.Sample(uv2f); - let sampleResult4 = tex2DArray.Sample(uv3f); - let sampleResult5 = tex3D.Sample(uv3f); - let sampleResult6 = texCube.Sample(uv3f); - - let depthSampleResult1 = tex1DDepth.SampleDepthComp(uv1f, depth); - let depthSampleResult2 = tex1DArrayDepth.SampleDepthComp(uv2f, depth); + let sampleResult2 = tex2D.Sample(uv2f); + let sampleResult3 = tex2DArray.Sample(uv3f); + let sampleResult4 = tex3D.Sample(uv3f); + let sampleResult5 = texCube.Sample(uv3f); + let depthSampleResult3 = tex2DDepth.SampleDepthComp(uv2f, depth); let depthSampleResult4 = tex2DArrayDepth.SampleDepthComp(uv3f, depth); let depthSampleResult5 = texCubeDepth.SampleDepthComp(uv3f, depth); @@ -167,13 +184,10 @@ fn main() ExpectGLSL(*shaderModule, R"( uniform sampler1D tex1D; -uniform sampler1DArray tex1DArray; uniform sampler2D tex2D; uniform sampler2DArray tex2DArray; uniform sampler3D tex3D; uniform samplerCube texCube; -uniform sampler1DShadow tex1DDepth; -uniform sampler1DArrayShadow tex1DArrayDepth; uniform sampler2DShadow tex2DDepth; uniform sampler2DArrayShadow tex2DArrayDepth; uniform samplerCubeShadow texCubeDepth; @@ -185,13 +199,10 @@ void main() vec2 uv2f = vec2(0.0, 1.0); vec3 uv3f = vec3(0.0, 1.0, 2.0); vec4 sampleResult1 = texture(tex1D, uv1f); - vec4 sampleResult2 = texture(tex1DArray, uv2f); - vec4 sampleResult3 = texture(tex2D, uv2f); - vec4 sampleResult4 = texture(tex2DArray, uv3f); - vec4 sampleResult5 = texture(tex3D, uv3f); - vec4 sampleResult6 = texture(texCube, uv3f); - float depthSampleResult1 = texture(tex1DDepth, vec3(uv1f, 0.0, depth)); - float depthSampleResult2 = texture(tex1DArrayDepth, vec3(uv2f, depth)); + vec4 sampleResult2 = texture(tex2D, uv2f); + vec4 sampleResult3 = texture(tex2DArray, uv3f); + vec4 sampleResult4 = texture(tex3D, uv3f); + vec4 sampleResult5 = texture(texCube, uv3f); float depthSampleResult3 = texture(tex2DDepth, vec3(uv2f, depth)); float depthSampleResult4 = texture(tex2DArrayDepth, vec4(uv3f, depth)); float depthSampleResult5 = texture(texCubeDepth, vec4(uv3f, depth)); @@ -203,16 +214,13 @@ void main() external { [set(0), binding(0)] tex1D: sampler1D[f32], - [set(0), binding(1)] tex1DArray: sampler1D_array[f32], - [set(0), binding(2)] tex2D: sampler2D[f32], - [set(0), binding(3)] tex2DArray: sampler2D_array[f32], - [set(0), binding(4)] tex3D: sampler3D[f32], - [set(0), binding(5)] texCube: sampler_cube[f32], - [set(0), binding(6)] tex1DDepth: depth_sampler1D[f32], - [set(0), binding(7)] tex1DArrayDepth: depth_sampler1D_array[f32], - [set(0), binding(8)] tex2DDepth: depth_sampler2D[f32], - [set(0), binding(9)] tex2DArrayDepth: depth_sampler2D_array[f32], - [set(0), binding(10)] texCubeDepth: depth_sampler_cube[f32] + [set(0), binding(1)] tex2D: sampler2D[f32], + [set(0), binding(2)] tex2DArray: sampler2D_array[f32], + [set(0), binding(3)] tex3D: sampler3D[f32], + [set(0), binding(4)] texCube: sampler_cube[f32], + [set(0), binding(5)] tex2DDepth: depth_sampler2D[f32], + [set(0), binding(6)] tex2DArrayDepth: depth_sampler2D_array[f32], + [set(0), binding(7)] texCubeDepth: depth_sampler_cube[f32] } [entry(frag)] @@ -223,13 +231,10 @@ fn main() let uv2f: vec2[f32] = vec2[f32](0.0, 1.0); let uv3f: vec3[f32] = vec3[f32](0.0, 1.0, 2.0); let sampleResult1: vec4[f32] = tex1D.Sample(uv1f); - let sampleResult2: vec4[f32] = tex1DArray.Sample(uv2f); - let sampleResult3: vec4[f32] = tex2D.Sample(uv2f); - let sampleResult4: vec4[f32] = tex2DArray.Sample(uv3f); - let sampleResult5: vec4[f32] = tex3D.Sample(uv3f); - let sampleResult6: vec4[f32] = texCube.Sample(uv3f); - let depthSampleResult1: f32 = tex1DDepth.SampleDepthComp(uv1f, depth); - let depthSampleResult2: f32 = tex1DArrayDepth.SampleDepthComp(uv2f, depth); + let sampleResult2: vec4[f32] = tex2D.Sample(uv2f); + let sampleResult3: vec4[f32] = tex2DArray.Sample(uv3f); + let sampleResult4: vec4[f32] = tex3D.Sample(uv3f); + let sampleResult5: vec4[f32] = texCube.Sample(uv3f); let depthSampleResult3: f32 = tex2DDepth.SampleDepthComp(uv2f, depth); let depthSampleResult4: f32 = tex2DArrayDepth.SampleDepthComp(uv3f, depth); let depthSampleResult5: f32 = texCubeDepth.SampleDepthComp(uv3f, depth); @@ -237,179 +242,265 @@ fn main() )"); ExpectSPIRV(*shaderModule, R"( - OpCapability Capability(Shader) - OpCapability Capability(Sampled1D) - OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) - OpEntryPoint ExecutionModel(Fragment) %59 "main" - OpExecutionMode %59 ExecutionMode(OriginUpperLeft) - OpSource SourceLanguage(NZSL) 4198400 - OpSourceExtension "Version: 1.1" - OpName %5 "tex1D" - OpName %9 "tex1DArray" - OpName %13 "tex2D" - OpName %17 "tex2DArray" - OpName %21 "tex3D" - OpName %25 "texCube" - OpName %29 "tex1DDepth" - OpName %33 "tex1DArrayDepth" - OpName %37 "tex2DDepth" - OpName %41 "tex2DArrayDepth" - OpName %45 "texCubeDepth" - OpName %59 "main" - OpDecorate %5 Decoration(Binding) 0 - OpDecorate %5 Decoration(DescriptorSet) 0 - OpDecorate %9 Decoration(Binding) 1 - OpDecorate %9 Decoration(DescriptorSet) 0 - OpDecorate %13 Decoration(Binding) 2 - OpDecorate %13 Decoration(DescriptorSet) 0 - OpDecorate %17 Decoration(Binding) 3 - OpDecorate %17 Decoration(DescriptorSet) 0 - OpDecorate %21 Decoration(Binding) 4 - OpDecorate %21 Decoration(DescriptorSet) 0 - OpDecorate %25 Decoration(Binding) 5 - OpDecorate %25 Decoration(DescriptorSet) 0 - OpDecorate %29 Decoration(Binding) 6 - OpDecorate %29 Decoration(DescriptorSet) 0 - OpDecorate %33 Decoration(Binding) 7 - OpDecorate %33 Decoration(DescriptorSet) 0 - OpDecorate %37 Decoration(Binding) 8 - OpDecorate %37 Decoration(DescriptorSet) 0 - OpDecorate %41 Decoration(Binding) 9 - OpDecorate %41 Decoration(DescriptorSet) 0 - OpDecorate %45 Decoration(Binding) 10 - OpDecorate %45 Decoration(DescriptorSet) 0 - %1 = OpTypeFloat 32 - %2 = OpTypeImage %1 Dim(Dim1D) 0 0 0 1 ImageFormat(Unknown) - %3 = OpTypeSampledImage %2 - %4 = OpTypePointer StorageClass(UniformConstant) %3 - %6 = OpTypeImage %1 Dim(Dim1D) 0 1 0 1 ImageFormat(Unknown) - %7 = OpTypeSampledImage %6 - %8 = OpTypePointer StorageClass(UniformConstant) %7 - %10 = OpTypeImage %1 Dim(Dim2D) 0 0 0 1 ImageFormat(Unknown) - %11 = OpTypeSampledImage %10 - %12 = OpTypePointer StorageClass(UniformConstant) %11 - %14 = OpTypeImage %1 Dim(Dim2D) 0 1 0 1 ImageFormat(Unknown) - %15 = OpTypeSampledImage %14 - %16 = OpTypePointer StorageClass(UniformConstant) %15 - %18 = OpTypeImage %1 Dim(Dim3D) 0 0 0 1 ImageFormat(Unknown) - %19 = OpTypeSampledImage %18 - %20 = OpTypePointer StorageClass(UniformConstant) %19 - %22 = OpTypeImage %1 Dim(Cube) 0 0 0 1 ImageFormat(Unknown) - %23 = OpTypeSampledImage %22 - %24 = OpTypePointer StorageClass(UniformConstant) %23 - %26 = OpTypeImage %1 Dim(Dim1D) 1 0 0 1 ImageFormat(Unknown) - %27 = OpTypeSampledImage %26 - %28 = OpTypePointer StorageClass(UniformConstant) %27 - %30 = OpTypeImage %1 Dim(Dim1D) 1 1 0 1 ImageFormat(Unknown) - %31 = OpTypeSampledImage %30 - %32 = OpTypePointer StorageClass(UniformConstant) %31 - %34 = OpTypeImage %1 Dim(Dim2D) 1 0 0 1 ImageFormat(Unknown) - %35 = OpTypeSampledImage %34 - %36 = OpTypePointer StorageClass(UniformConstant) %35 - %38 = OpTypeImage %1 Dim(Dim2D) 1 1 0 1 ImageFormat(Unknown) - %39 = OpTypeSampledImage %38 - %40 = OpTypePointer StorageClass(UniformConstant) %39 - %42 = OpTypeImage %1 Dim(Cube) 1 0 0 1 ImageFormat(Unknown) - %43 = OpTypeSampledImage %42 - %44 = OpTypePointer StorageClass(UniformConstant) %43 - %46 = OpTypeVoid - %47 = OpTypeFunction %46 - %48 = OpConstant %1 f32(0.5) - %49 = OpTypePointer StorageClass(Function) %1 - %50 = OpConstant %1 f32(0) - %51 = OpConstant %1 f32(1) - %52 = OpTypeVector %1 2 - %53 = OpTypePointer StorageClass(Function) %52 - %54 = OpConstant %1 f32(2) - %55 = OpTypeVector %1 3 - %56 = OpTypePointer StorageClass(Function) %55 - %57 = OpTypeVector %1 4 - %58 = OpTypePointer StorageClass(Function) %57 - %5 = OpVariable %4 StorageClass(UniformConstant) - %9 = OpVariable %8 StorageClass(UniformConstant) - %13 = OpVariable %12 StorageClass(UniformConstant) - %17 = OpVariable %16 StorageClass(UniformConstant) - %21 = OpVariable %20 StorageClass(UniformConstant) - %25 = OpVariable %24 StorageClass(UniformConstant) - %29 = OpVariable %28 StorageClass(UniformConstant) - %33 = OpVariable %32 StorageClass(UniformConstant) - %37 = OpVariable %36 StorageClass(UniformConstant) - %41 = OpVariable %40 StorageClass(UniformConstant) - %45 = OpVariable %44 StorageClass(UniformConstant) - %59 = OpFunction %46 FunctionControl(0) %47 - %60 = OpLabel - %61 = OpVariable %49 StorageClass(Function) - %62 = OpVariable %49 StorageClass(Function) - %63 = OpVariable %53 StorageClass(Function) - %64 = OpVariable %56 StorageClass(Function) - %65 = OpVariable %58 StorageClass(Function) - %66 = OpVariable %58 StorageClass(Function) - %67 = OpVariable %58 StorageClass(Function) - %68 = OpVariable %58 StorageClass(Function) - %69 = OpVariable %58 StorageClass(Function) - %70 = OpVariable %58 StorageClass(Function) - %71 = OpVariable %49 StorageClass(Function) - %72 = OpVariable %49 StorageClass(Function) - %73 = OpVariable %49 StorageClass(Function) - %74 = OpVariable %49 StorageClass(Function) - %75 = OpVariable %49 StorageClass(Function) - OpStore %61 %48 - OpStore %62 %50 - %76 = OpCompositeConstruct %52 %50 %51 - OpStore %63 %76 - %77 = OpCompositeConstruct %55 %50 %51 %54 - OpStore %64 %77 - %78 = OpLoad %3 %5 - %79 = OpLoad %1 %62 - %80 = OpImageSampleImplicitLod %57 %78 %79 - OpStore %65 %80 - %81 = OpLoad %7 %9 - %82 = OpLoad %52 %63 - %83 = OpImageSampleImplicitLod %57 %81 %82 - OpStore %66 %83 - %84 = OpLoad %11 %13 - %85 = OpLoad %52 %63 - %86 = OpImageSampleImplicitLod %57 %84 %85 - OpStore %67 %86 - %87 = OpLoad %15 %17 - %88 = OpLoad %55 %64 - %89 = OpImageSampleImplicitLod %57 %87 %88 - OpStore %68 %89 - %90 = OpLoad %19 %21 - %91 = OpLoad %55 %64 - %92 = OpImageSampleImplicitLod %57 %90 %91 - OpStore %69 %92 - %93 = OpLoad %23 %25 - %94 = OpLoad %55 %64 - %95 = OpImageSampleImplicitLod %57 %93 %94 - OpStore %70 %95 - %96 = OpLoad %27 %29 - %97 = OpLoad %1 %62 - %98 = OpLoad %1 %61 - %99 = OpImageSampleDrefImplicitLod %1 %96 %97 %98 - OpStore %71 %99 -%100 = OpLoad %31 %33 -%101 = OpLoad %52 %63 -%102 = OpLoad %1 %61 -%103 = OpImageSampleDrefImplicitLod %1 %100 %101 %102 - OpStore %72 %103 -%104 = OpLoad %35 %37 -%105 = OpLoad %52 %63 -%106 = OpLoad %1 %61 -%107 = OpImageSampleDrefImplicitLod %1 %104 %105 %106 - OpStore %73 %107 -%108 = OpLoad %39 %41 -%109 = OpLoad %55 %64 -%110 = OpLoad %1 %61 -%111 = OpImageSampleDrefImplicitLod %1 %108 %109 %110 - OpStore %74 %111 -%112 = OpLoad %43 %45 -%113 = OpLoad %55 %64 -%114 = OpLoad %1 %61 -%115 = OpImageSampleDrefImplicitLod %1 %112 %113 %114 - OpStore %75 %115 - OpReturn - OpFunctionEnd)", {}, {}, true); + OpCapability Capability(Shader) + OpCapability Capability(Sampled1D) + OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) + OpEntryPoint ExecutionModel(Fragment) %47 "main" + OpExecutionMode %47 ExecutionMode(OriginUpperLeft) + OpSource SourceLanguage(NZSL) 4198400 + OpSourceExtension "Version: 1.1" + OpName %5 "tex1D" + OpName %9 "tex2D" + OpName %13 "tex2DArray" + OpName %17 "tex3D" + OpName %21 "texCube" + OpName %25 "tex2DDepth" + OpName %29 "tex2DArrayDepth" + OpName %33 "texCubeDepth" + OpName %47 "main" + OpDecorate %5 Decoration(Binding) 0 + OpDecorate %5 Decoration(DescriptorSet) 0 + OpDecorate %9 Decoration(Binding) 1 + OpDecorate %9 Decoration(DescriptorSet) 0 + OpDecorate %13 Decoration(Binding) 2 + OpDecorate %13 Decoration(DescriptorSet) 0 + OpDecorate %17 Decoration(Binding) 3 + OpDecorate %17 Decoration(DescriptorSet) 0 + OpDecorate %21 Decoration(Binding) 4 + OpDecorate %21 Decoration(DescriptorSet) 0 + OpDecorate %25 Decoration(Binding) 5 + OpDecorate %25 Decoration(DescriptorSet) 0 + OpDecorate %29 Decoration(Binding) 6 + OpDecorate %29 Decoration(DescriptorSet) 0 + OpDecorate %33 Decoration(Binding) 7 + OpDecorate %33 Decoration(DescriptorSet) 0 + %1 = OpTypeFloat 32 + %2 = OpTypeImage %1 Dim(Dim1D) 0 0 0 1 ImageFormat(Unknown) + %3 = OpTypeSampledImage %2 + %4 = OpTypePointer StorageClass(UniformConstant) %3 + %6 = OpTypeImage %1 Dim(Dim2D) 0 0 0 1 ImageFormat(Unknown) + %7 = OpTypeSampledImage %6 + %8 = OpTypePointer StorageClass(UniformConstant) %7 +%10 = OpTypeImage %1 Dim(Dim2D) 0 1 0 1 ImageFormat(Unknown) +%11 = OpTypeSampledImage %10 +%12 = OpTypePointer StorageClass(UniformConstant) %11 +%14 = OpTypeImage %1 Dim(Dim3D) 0 0 0 1 ImageFormat(Unknown) +%15 = OpTypeSampledImage %14 +%16 = OpTypePointer StorageClass(UniformConstant) %15 +%18 = OpTypeImage %1 Dim(Cube) 0 0 0 1 ImageFormat(Unknown) +%19 = OpTypeSampledImage %18 +%20 = OpTypePointer StorageClass(UniformConstant) %19 +%22 = OpTypeImage %1 Dim(Dim2D) 1 0 0 1 ImageFormat(Unknown) +%23 = OpTypeSampledImage %22 +%24 = OpTypePointer StorageClass(UniformConstant) %23 +%26 = OpTypeImage %1 Dim(Dim2D) 1 1 0 1 ImageFormat(Unknown) +%27 = OpTypeSampledImage %26 +%28 = OpTypePointer StorageClass(UniformConstant) %27 +%30 = OpTypeImage %1 Dim(Cube) 1 0 0 1 ImageFormat(Unknown) +%31 = OpTypeSampledImage %30 +%32 = OpTypePointer StorageClass(UniformConstant) %31 +%34 = OpTypeVoid +%35 = OpTypeFunction %34 +%36 = OpConstant %1 f32(0.5) +%37 = OpTypePointer StorageClass(Function) %1 +%38 = OpConstant %1 f32(0) +%39 = OpConstant %1 f32(1) +%40 = OpTypeVector %1 2 +%41 = OpTypePointer StorageClass(Function) %40 +%42 = OpConstant %1 f32(2) +%43 = OpTypeVector %1 3 +%44 = OpTypePointer StorageClass(Function) %43 +%45 = OpTypeVector %1 4 +%46 = OpTypePointer StorageClass(Function) %45 + %5 = OpVariable %4 StorageClass(UniformConstant) + %9 = OpVariable %8 StorageClass(UniformConstant) +%13 = OpVariable %12 StorageClass(UniformConstant) +%17 = OpVariable %16 StorageClass(UniformConstant) +%21 = OpVariable %20 StorageClass(UniformConstant) +%25 = OpVariable %24 StorageClass(UniformConstant) +%29 = OpVariable %28 StorageClass(UniformConstant) +%33 = OpVariable %32 StorageClass(UniformConstant) +%47 = OpFunction %34 FunctionControl(0) %35 +%48 = OpLabel +%49 = OpVariable %37 StorageClass(Function) +%50 = OpVariable %37 StorageClass(Function) +%51 = OpVariable %41 StorageClass(Function) +%52 = OpVariable %44 StorageClass(Function) +%53 = OpVariable %46 StorageClass(Function) +%54 = OpVariable %46 StorageClass(Function) +%55 = OpVariable %46 StorageClass(Function) +%56 = OpVariable %46 StorageClass(Function) +%57 = OpVariable %46 StorageClass(Function) +%58 = OpVariable %37 StorageClass(Function) +%59 = OpVariable %37 StorageClass(Function) +%60 = OpVariable %37 StorageClass(Function) + OpStore %49 %36 + OpStore %50 %38 +%61 = OpCompositeConstruct %40 %38 %39 + OpStore %51 %61 +%62 = OpCompositeConstruct %43 %38 %39 %42 + OpStore %52 %62 +%63 = OpLoad %3 %5 +%64 = OpLoad %1 %50 +%65 = OpImageSampleImplicitLod %45 %63 %64 + OpStore %53 %65 +%66 = OpLoad %7 %9 +%67 = OpLoad %40 %51 +%68 = OpImageSampleImplicitLod %45 %66 %67 + OpStore %54 %68 +%69 = OpLoad %11 %13 +%70 = OpLoad %43 %52 +%71 = OpImageSampleImplicitLod %45 %69 %70 + OpStore %55 %71 +%72 = OpLoad %15 %17 +%73 = OpLoad %43 %52 +%74 = OpImageSampleImplicitLod %45 %72 %73 + OpStore %56 %74 +%75 = OpLoad %19 %21 +%76 = OpLoad %43 %52 +%77 = OpImageSampleImplicitLod %45 %75 %76 + OpStore %57 %77 +%78 = OpLoad %23 %25 +%79 = OpLoad %40 %51 +%80 = OpLoad %1 %49 +%81 = OpImageSampleDrefImplicitLod %1 %78 %79 %80 + OpStore %58 %81 +%82 = OpLoad %27 %29 +%83 = OpLoad %43 %52 +%84 = OpLoad %1 %49 +%85 = OpImageSampleDrefImplicitLod %1 %82 %83 %84 + OpStore %59 %85 +%86 = OpLoad %31 %33 +%87 = OpLoad %43 %52 +%88 = OpLoad %1 %49 +%89 = OpImageSampleDrefImplicitLod %1 %86 %87 %88 + OpStore %60 %89 + OpReturn + OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@group(0) @binding(0) var tex1D: texture_1d; +@group(0) @binding(1) var tex1DSampler: sampler; +@group(0) @binding(2) var tex2D: texture_2d; +@group(0) @binding(3) var tex2DSampler: sampler; +@group(0) @binding(4) var tex2DArray: texture_2d_array; +@group(0) @binding(5) var tex2DArraySampler: sampler; +@group(0) @binding(6) var tex3D: texture_3d; +@group(0) @binding(7) var tex3DSampler: sampler; +@group(0) @binding(8) var texCube: texture_cube; +@group(0) @binding(9) var texCubeSampler: sampler; +@group(0) @binding(10) var tex2DDepth: texture_depth_2d; +@group(0) @binding(11) var tex2DDepthSampler: sampler_comparison; +@group(0) @binding(12) var tex2DArrayDepth: texture_depth_2d_array; +@group(0) @binding(13) var tex2DArrayDepthSampler: sampler_comparison; +@group(0) @binding(14) var texCubeDepth: texture_depth_cube; +@group(0) @binding(15) var texCubeDepthSampler: sampler_comparison; + +@fragment +fn main() +{ + var depth: f32 = 0.5; + var uv1f: f32 = 0.0; + var uv2f: vec2 = vec2(0.0, 1.0); + var uv3f: vec3 = vec3(0.0, 1.0, 2.0); + var sampleResult1: vec4 = textureSample(tex1D, tex1DSampler, uv1f); + var sampleResult2: vec4 = textureSample(tex2D, tex2DSampler, uv2f); + var sampleResult3: vec4 = textureSample(tex2DArray, tex2DArraySampler, uv3f.xy, u32(uv3f.z)); + var sampleResult4: vec4 = textureSample(tex3D, tex3DSampler, uv3f); + var sampleResult5: vec4 = textureSample(texCube, texCubeSampler, uv3f); + var depthSampleResult3: f32 = textureSampleCompare(tex2DDepth, tex2DDepthSampler, uv2f, depth); + var depthSampleResult4: f32 = textureSampleCompare(tex2DArrayDepth, tex2DArrayDepthSampler, uv3f.xy, u32(uv3f.z), depth); + var depthSampleResult5: f32 = textureSampleCompare(texCubeDepth, texCubeDepthSampler, uv3f, depth); +} +)"); + } + + WHEN("testing texture 1d array intrinsics") + { + std::string_view nzslSource = R"( +[nzsl_version("1.1")] +[feature(texture1D)] +module; + +[auto_binding] +external +{ + tex1DArray: sampler1D_array[f32], +} + +[entry(frag)] +fn main() +{ + let sampleResult = tex1DArray.Sample(vec2[f32](0.0, 1.0)); +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + ResolveModule(*shaderModule); + + // sampler1D and sampler1D_array are not supported by GLSL ES + nzsl::GlslWriter::Environment glslEnv; + glslEnv.glES = false; + + ExpectGLSL(*shaderModule, R"( +uniform sampler1DArray tex1DArray; + +void main() +{ + vec4 sampleResult = texture(tex1DArray, vec2(0.0, 1.0)); +} +)", {}, glslEnv); + + ExpectNZSL(*shaderModule, R"( +[auto_binding(true)] +external +{ + [set(0), binding(0)] tex1DArray: sampler1D_array[f32] +} + +[entry(frag)] +fn main() +{ + let sampleResult: vec4[f32] = tex1DArray.Sample(vec2[f32](0.0, 1.0)); +} +)"); + + ExpectSPIRV(*shaderModule, R"( + OpCapability Capability(Shader) + OpCapability Capability(Sampled1D) + OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) + OpEntryPoint ExecutionModel(Fragment) %13 "main" + OpExecutionMode %13 ExecutionMode(OriginUpperLeft) + OpSource SourceLanguage(NZSL) 4198400 + OpSourceExtension "Version: 1.1" + OpName %5 "tex1DArray" + OpName %13 "main" + OpDecorate %5 Decoration(Binding) 0 + OpDecorate %5 Decoration(DescriptorSet) 0 + %1 = OpTypeFloat 32 + %2 = OpTypeImage %1 Dim(Dim1D) 0 1 0 1 ImageFormat(Unknown) + %3 = OpTypeSampledImage %2 + %4 = OpTypePointer StorageClass(UniformConstant) %3 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpConstant %1 f32(0) + %9 = OpConstant %1 f32(1) +%10 = OpTypeVector %1 2 +%11 = OpTypeVector %1 4 +%12 = OpTypePointer StorageClass(Function) %11 + %5 = OpVariable %4 StorageClass(UniformConstant) +%13 = OpFunction %6 FunctionControl(0) %7 +%14 = OpLabel +%15 = OpVariable %12 StorageClass(Function) +%16 = OpLoad %3 %5 +%17 = OpCompositeConstruct %10 %8 %9 +%18 = OpImageSampleImplicitLod %11 %16 %17 + OpStore %15 %18 + OpReturn + OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter wgslWriter; + CHECK_THROWS_WITH(wgslWriter.Generate(*shaderModule), "texture 1D array are not supported by WGSL"); } WHEN("testing math intrinsics") @@ -1142,6 +1233,135 @@ fn main() OpStore %166 %419 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var d1: f64 = 4.2; + var d2: f64 = 133.699999999999989; + var d3: f64 = -123.400000000000006; + var f1: f32 = 4.2; + var f2: f32 = 133.699997; + var f3: f32 = -123.400002; + var i1: i32 = 42; + var i2: i32 = 1337; + var i3: i32 = -1234; + var u1: u32 = 42u; + var u2: u32 = 1337u; + var u3: u32 = 123456789u; + var uv: vec2 = vec2(0.0, 1.0); + var v1: vec3 = vec3(0.0, 1.0, 2.0); + var v2: vec3 = vec3(2.0, 1.0, 0.0); + var v3: vec3 = vec3(1.0, 0.0, 2.0); + var dv1: vec3 = vec3(0.0, 1.0, 2.0); + var dv2: vec3 = vec3(2.0, 1.0, 0.0); + var dv3: vec3 = vec3(1.0, 0.0, 2.0); + var iv1: vec3 = vec3(0, 1, 2); + var iv2: vec3 = vec3(2, 1, 0); + var iv3: vec3 = vec3(1, 0, 2); + var uv1: vec3 = vec3(0u, 1u, 2u); + var uv2: vec3 = vec3(2u, 1u, 0u); + var uv3: vec3 = vec3(1u, 0u, 2u); + var absResult1: f32 = abs(f1); + var absResult2: vec3 = abs(v1); + var absResult3: f64 = abs(d1); + var absResult3_2: vec3 = abs(dv1); + var ceilResult1: f32 = ceil(f1); + var ceilResult2: vec3 = ceil(v1); + var ceilResult3: f64 = ceil(d1); + var ceilResult4: vec3 = ceil(dv1); + var clampResult1: f32 = clamp(f1, f3, f2); + var clampResult2: vec3 = clamp(v1, v3, v2); + var clampResult3: f64 = clamp(d1, d3, d2); + var clampResult4: vec3 = clamp(dv1, dv3, dv2); + var crossResult1: vec3 = cross(v1, v2); + var crossResult2: vec3 = cross(dv1, dv2); + var distanceResult1: f32 = distance(v1, v2); + var distanceResult2: f64 = distance(dv1, dv2); + var dotResult1: f32 = dot(v1, v2); + var dotResult2: f64 = dot(dv1, dv2); + var expResult1: vec3 = exp(v1); + var expResult2: f32 = exp(f1); + var exp2Result1: vec3 = exp2(v1); + var exp2Result2: f32 = exp2(f1); + var floorResult1: f32 = floor(f1); + var floorResult2: vec3 = floor(v1); + var floorResult3: f64 = floor(d1); + var floorResult4: vec3 = floor(dv1); + var fractResult1: f32 = fract(f1); + var fractResult2: vec3 = fract(v1); + var fractResult3: f64 = fract(d1); + var fractResult4: vec3 = fract(dv1); + var rsqrtResult1: f32 = inverseSqrt(f1); + var rsqrtResult2: vec3 = inverseSqrt(v1); + var rsqrtResult3: f64 = inverseSqrt(d1); + var rsqrtResult4: vec3 = inverseSqrt(dv1); + var lengthResult1: f32 = length(v1); + var lengthResult2: f64 = length(dv1); + var lerpResult1: f32 = mix(f1, f3, f2); + var lerpResult2: vec3 = mix(v1, v3, v2); + var lerpResult3: f64 = mix(d1, d3, d2); + var lerpResult4: vec3 = mix(dv1, dv3, dv2); + var logResult1: vec3 = log(v1); + var logResult2: f32 = log(f1); + var log2Result1: vec3 = log2(v1); + var log2Result2: f32 = log2(f1); + var maxResult1: f32 = max(f1, f2); + var maxResult2: i32 = max(i1, i2); + var maxResult3: u32 = max(u1, u2); + var maxResult4: vec3 = max(v1, v2); + var maxResult5: vec3 = max(dv1, dv2); + var maxResult6: vec3 = max(iv1, iv2); + var maxResult7: vec3 = max(uv1, uv2); + var minResult1: f32 = min(f1, f2); + var minResult2: i32 = min(i1, i2); + var minResult3: u32 = min(u1, u2); + var minResult4: vec3 = min(v1, v2); + var minResult5: vec3 = min(dv1, dv2); + var minResult6: vec3 = min(iv1, iv2); + var minResult7: vec3 = min(uv1, uv2); + var normalizeResult1: vec3 = normalize(v1); + var normalizeResult2: vec3 = normalize(dv1); + var powResult1: f32 = pow(f1, f2); + var powResult2: vec3 = pow(v1, v2); + var reflectResult1: vec3 = reflect(v1, v2); + var reflectResult2: vec3 = reflect(dv1, dv2); + var roundResult1: f32 = round(f1); + var roundResult2: vec3 = round(v1); + var roundResult3: f64 = round(d1); + var roundResult4: vec3 = round(dv1); + var roundevenResult1: f32 = round(f1); + var roundevenResult2: vec3 = round(v1); + var roundevenResult3: f64 = round(d1); + var roundevenResult4: vec3 = round(dv1); + var signResult1: f32 = sign(f1); + var signResult2: i32 = sign(i1); + var signResult3: f64 = sign(d1); + var signResult4: vec3 = sign(v1); + var signResult5: vec3 = sign(dv1); + var signResult6: vec3 = sign(iv1); + var smoothStepResult1: f32 = smoothstep(f1, f2, f3); + var smoothStepResult2: vec3 = smoothstep(v1, v2, v3); + var smoothStepResult1_2: f64 = smoothstep(d1, d2, d3); + var smoothStepResult2_2: vec3 = smoothstep(dv1, dv2, dv3); + var sqrtResult1: f32 = sqrt(f1); + var sqrtResult2: vec3 = sqrt(v1); + var sqrtResult3: f64 = sqrt(d1); + var sqrtResult4: vec3 = sqrt(dv1); + var stepResult1: f32 = step(f1, f2); + var stepResult2: vec3 = step(v1, v2); + var stepResult1_2: f64 = step(d1, d2); + var stepResult2_2: vec3 = step(dv1, dv2); + var truncResult1: f32 = trunc(f1); + var truncResult2: vec3 = trunc(v1); + var truncResult3: f64 = trunc(d1); + var truncResult4: vec3 = trunc(dv1); +} +)", {}, wgslEnv); } WHEN("testing matrix intrinsics") @@ -1340,6 +1560,40 @@ fn main() OpStore %66 %106 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var _nzsl_matrix: mat4x4; + _nzsl_matrix[0u] = vec4(0.0, 1.0, 2.0, 3.0); + _nzsl_matrix[1u] = vec4(4.0, 5.0, 6.0, 7.0); + _nzsl_matrix[2u] = vec4(8.0, 9.0, 10.0, 11.0); + _nzsl_matrix[3u] = vec4(12.0, 13.0, 14.0, 15.0); + var m1: mat4x4 = _nzsl_matrix; + var _nzsl_matrix_2: mat2x3; + _nzsl_matrix_2[0u] = vec3(0.0, 1.0, 2.0); + _nzsl_matrix_2[1u] = vec3(3.0, 4.0, 5.0); + var m2: mat2x3 = _nzsl_matrix_2; + var _nzsl_matrix_3: mat3x3; + _nzsl_matrix_3[0u] = vec3(0.0, 1.0, 2.0); + _nzsl_matrix_3[1u] = vec3(3.0, 4.0, 5.0); + _nzsl_matrix_3[2u] = vec3(6.0, 7.0, 8.0); + var m3: mat3x3 = _nzsl_matrix_3; + var _nzsl_matrix_4: mat3x2; + _nzsl_matrix_4[0u] = vec2(0.0, 1.0); + _nzsl_matrix_4[1u] = vec2(2.0, 3.0); + _nzsl_matrix_4[2u] = vec2(4.0, 5.0); + var m4: mat3x2 = _nzsl_matrix_4; + var inverseResult1: mat4x4 = _nzslMatrixInverse4x4f32(m1); + var inverseResult2: mat3x3 = _nzslMatrixInverse3x3f64(m3); + var transposeResult1: mat3x2 = transpose(m2); + var transposeResult2: mat2x3 = transpose(m4); +} +)", {}, wgslEnv); } WHEN("testing trigonometry intrinsics") @@ -1562,6 +1816,49 @@ fn main() OpStore %57 %115 OpReturn OpFunctionEnd)", {}, {}, true); + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +fn main() +{ + var d1: f64 = 42.0; + var d2: f64 = 1337.0; + var f1: f32 = 42.0; + var f2: f32 = 1337.0; + var v1: vec3 = vec3(0.0, 1.0, 2.0); + var v2: vec3 = vec3(2.0, 1.0, 0.0); + var dv1: vec3 = vec3(0.0, 1.0, 2.0); + var dv2: vec3 = vec3(2.0, 1.0, 0.0); + var acosResult1: f32 = acos(f1); + var acosResult2: vec3 = acos(v1); + var acoshResult1: f32 = acosh(f1); + var acoshResult2: vec3 = acosh(v1); + var asinResult1: f32 = asin(f1); + var asinResult2: vec3 = asin(v1); + var asinhResult1: f32 = asinh(f1); + var asinhResult2: vec3 = asinh(v1); + var atanResult1: f32 = atan(f1); + var atanResult2: vec3 = atan(v1); + var atan2Result1: f32 = atan2(f1, f2); + var atan2Result2: vec3 = atan2(v1, v2); + var atanhResult1: f32 = atanh(f1); + var atanhResult2: vec3 = atanh(v1); + var cosResult1: f32 = cos(f1); + var cosResult2: vec3 = cos(v1); + var coshResult1: f32 = cosh(f1); + var coshResult2: vec3 = cosh(v1); + var deg2radResult1: f32 = radians(f1); + var deg2radResult2: vec3 = radians(v1); + var rad2degResult1: f32 = degrees(f1); + var rad2degResult2: vec3 = degrees(v1); + var sinResult1: f32 = sin(f1); + var sinResult2: vec3 = sin(v1); + var sinhResult1: f32 = sinh(f1); + var sinhResult2: vec3 = sinh(v1); +} +)", {}, wgslEnv); } WHEN("testing select intrinsic") @@ -2116,6 +2413,50 @@ fn main() OpReturn OpFunctionEnd)", {}, env, true); } + + nzsl::WgslWriter::Environment wgslEnv; + wgslEnv.featuresCallback = [](std::string_view) { return true; }; + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var b1: bool = false; + var b2: bool = true; + var d1: f64 = 4.2; + var d2: f64 = 133.699999999999989; + var f1: f32 = 4.2; + var f2: f32 = 133.699997; + var i1: i32 = 42; + var i2: i32 = 1337; + var u1: u32 = 42u; + var u2: u32 = 1337u; + var v1: vec3 = vec3(0.0, 1.0, 2.0); + var v2: vec3 = vec3(2.0, 1.0, 0.0); + var bv1: vec3 = vec3(true, false, true); + var bv2: vec3 = vec3(false, false, true); + var dv1: vec3 = vec3(0.0, 1.0, 2.0); + var dv2: vec3 = vec3(2.0, 1.0, 0.0); + var iv1: vec3 = vec3(0, 1, 2); + var iv2: vec3 = vec3(2, 1, 0); + var uv1: vec3 = vec3(0u, 1u, 2u); + var uv2: vec3 = vec3(2u, 1u, 0u); + var result: f64 = select(d2, d1, b1); + var result_2: f32 = select(f2, f1, b2); + var result_3: i32 = select(i2, i1, b1); + var result_4: u32 = select(u2, u1, b2); + var result_5: vec3 = select(v2, v1, vec3(b1)); + var result_6: vec3 = select(bv2, bv1, vec3(b2)); + var result_7: vec3 = select(dv2, dv1, vec3(b1)); + var result_8: vec3 = select(iv2, iv1, vec3(b2)); + var result_9: vec3 = select(uv2, uv1, vec3(b1)); + var result_10: vec3 = select(v2, v1, bv1); + var result_11: vec3 = select(bv2, bv1, bv2); + var result_12: vec3 = select(dv2, dv1, bv1); + var result_13: vec3 = select(iv2, iv1, bv2); + var result_14: vec3 = select(uv2, uv1, bv1); +} +)", {}, wgslEnv); } WHEN("testing all/any/not intrinsics") @@ -2187,6 +2528,17 @@ fn main() OpStore %14 %21 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: vec3 = vec3(true, false, false); + var r: bool = all(x); + var r_2: bool = any(x); + var r_3: vec3 = !(x); +} +)"); } WHEN("testing isinf/isnan intrinsics") @@ -2254,5 +2606,25 @@ fn main() OpStore %17 %21 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +fn _nzslRatiof32(n: f32, d: f32) -> f32 +{ + return n / d; +} + +fn _nzslInfinityf32() -> f32 +{ + return _nzslRatiof32(1.0, 0.0); +} + +@fragment +fn main() +{ + var x: vec3 = vec3(1.0, 2.0, 3.0); + var r: vec3 = vec3(x.x == _nzslInfinityf32(), x.y == _nzslInfinityf32(), x.z == _nzslInfinityf32()); + var r_2: vec3 = x != x; +} +)"); } } diff --git a/tests/src/Tests/LayoutTests.cpp b/tests/src/Tests/LayoutTests.cpp index ee64a641..cd807b1f 100644 --- a/tests/src/Tests/LayoutTests.cpp +++ b/tests/src/Tests/LayoutTests.cpp @@ -119,6 +119,25 @@ fn main() OpReturn OpFunctionEnd )", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +// std140 layout +struct Foo +{ + v0: vec3, + v1: vec3, + v2: f32, + _padding0: f32 +} + +@group(0) @binding(0) var foo: Foo; + +@fragment +fn main() +{ + var value: f32 = 0.0; +} +)"); } SECTION("std430") @@ -282,6 +301,32 @@ fn main() OpReturn OpFunctionEnd )", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +struct Bar +{ + v0: vec2 +} + +struct Foo +{ + v0: f32, + v1: vec3, + v2: array, + v3: vec2, + v4: Bar, + v5: vec3, + v6: f32 +} + +@group(0) @binding(0) var foo: Foo; + +@fragment +fn main() +{ + var value: f32 = 0.0; +} +)"); } SECTION("scalar") @@ -415,6 +460,21 @@ fn main() OpFunctionEnd )", {}, spirvEnv, true, spvValidatorOptions); } - } + ExpectWGSL(*shaderModule, R"( +struct Foo +{ + v0: vec3, + v1: vec3 +} + +@group(0) @binding(0) var foo: Foo; + +@fragment +fn main() +{ + var value: f32 = 0.0; +} +)"); + } } diff --git a/tests/src/Tests/LiteralTests.cpp b/tests/src/Tests/LiteralTests.cpp index 1ccdd38f..2ea0c51c 100644 --- a/tests/src/Tests/LiteralTests.cpp +++ b/tests/src/Tests/LiteralTests.cpp @@ -85,6 +85,21 @@ fn foo() OpStore %31 %20 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn foo() +{ + var bar: f32 = -1.0; + var _nzsl_cachedResult: f32 = 1.0 + bar; + var bar_2: vec3 = vec3(_nzsl_cachedResult, _nzsl_cachedResult, _nzsl_cachedResult); + var bar_3: vec3 = vec3(2.0, 2.0, 2.0); + var bar_4: f32 = (max(1.0, 2.0)) + (min(2.0, 1.0)); + var bar_5: f32 = max(min(1.0, 2.0), 3.0); + var bar_6: u32 = max(1u, 2u); + var bar_7: vec3 = vec3(1, 2, 3); +} +)"); } @@ -140,5 +155,14 @@ fn foo() OpStore %10 %6 OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn foo() +{ + var bar: u32 = 4u; + var bar_2: u32 = 5u; +} +)"); } } diff --git a/tests/src/Tests/LoopTests.cpp b/tests/src/Tests/LoopTests.cpp index c00f5446..085bddf3 100644 --- a/tests/src/Tests/LoopTests.cpp +++ b/tests/src/Tests/LoopTests.cpp @@ -94,6 +94,21 @@ OpLabel OpReturn OpFunctionEnd)"); + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var value: f32 = 0.0; + var i: i32 = 0; + while (i < 10) + { + value += 0.1; + i += 1; + } + +} +)"); + WHEN("using break and continue") { std::string_view nzslSource2 = R"( @@ -237,6 +252,33 @@ fn main() %25 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule2, R"( +@fragment +fn main() +{ + var value: f32 = 0.0; + var value2: f32 = 0.0; + var i: i32 = 0; + while (i < 10) + { + if (i >= 8) + { + break; + } + + value += 0.1; + i += 1; + if (i == 4) + { + continue; + } + + value2 += value; + } + +} +)"); } } @@ -409,6 +451,47 @@ OpLabel OpReturn OpFunctionEnd)"); + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: i32 = 0; + { + var v: i32 = 0; + var _nzsl_to: i32 = 10; + while (v < _nzsl_to) + { + x += v; + { + var v_2: i32 = 5; + var _nzsl_to_2: i32 = 7; + while (v_2 < _nzsl_to_2) + { + x += v_2; + v_2 += 1; + } + + } + + v += 1; + } + + } + + { + var v: i32 = 0; + var _nzsl_to: i32 = 20; + while (v < _nzsl_to) + { + x += v; + v += 1; + } + + } + +} +)"); + WHEN("using break and continue") { @@ -534,6 +617,35 @@ fn main() %18 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule2, R"( +@fragment +fn main() +{ + var x: i32 = 0; + { + var v: i32 = 0; + var _nzsl_to: i32 = 10; + while (v < _nzsl_to) + { + if (v == 4) + { + continue; + } + + x += v; + if (v >= 8) + { + break; + } + + v += 1; + } + + } + +} +)"); } } @@ -622,6 +734,26 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var x: i32 = 0; + { + var v: i32 = 0; + var _nzsl_to: i32 = 10; + var _nzsl_step: i32 = 2; + while (v < _nzsl_to) + { + x += v; + v += _nzsl_step; + } + + } + +} +)"); } WHEN("using a for-each") @@ -718,7 +850,40 @@ OpLabel OpReturn OpFunctionEnd)"); + ExpectWGSL(*shaderModule, R"( +struct f32_stride16 +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// std140 layout +struct inputStruct +{ + value: array +} +@group(0) @binding(0) var data: inputStruct; + +@fragment +fn main() +{ + var x: f32 = 0.0; + { + var _nzsl_counter: u32 = 0u; + while (_nzsl_counter < 10u) + { + var v: f32 = data.value[_nzsl_counter].value; + x += v; + _nzsl_counter += 1u; + } + + } + +} +)"); WHEN("using break and continue") { @@ -856,6 +1021,51 @@ fn main() %27 = OpLabel OpReturn OpFunctionEnd)", {}, {}, true); + + ExpectWGSL(*shaderModule2, R"( +struct f32_stride16 +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// std140 layout +struct inputStruct +{ + value: array +} + +@group(0) @binding(0) var data: inputStruct; + +@fragment +fn main() +{ + var x: f32 = 0.0; + { + var _nzsl_counter: u32 = 0u; + while (_nzsl_counter < 10u) + { + var v: f32 = data.value[_nzsl_counter].value; + if (v < 0.0) + { + continue; + } + + x += v; + if (x >= 10.0) + { + break; + } + + _nzsl_counter += 1u; + } + + } + +} +)"); } } @@ -945,5 +1155,25 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var cascadeIndex: u32 = 0u; + var cascadeCount: u32 = 4u; + { + var index: u32 = 0u; + var _nzsl_to: u32 = cascadeCount; + while (index < _nzsl_to) + { + cascadeIndex = index; + index += 1u; + } + + } + +} +)"); } } diff --git a/tests/src/Tests/ModuleTests.cpp b/tests/src/Tests/ModuleTests.cpp index a447fa14..916d46fc 100644 --- a/tests/src/Tests/ModuleTests.cpp +++ b/tests/src/Tests/ModuleTests.cpp @@ -64,13 +64,13 @@ struct Unused {} [export] struct InputData { - value: f32 + [location(0)] value: f32 } [export] struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } )"; @@ -92,9 +92,10 @@ external ExtData fn main(input: InputData) -> OutputData { let data = ExtData.block.data; + let value = GetDataValue(data) * input.value * Pi; let output: OutputData; - output.value = GetDataValue(data) * input.value * Pi; + output.value = vec4[f32](value, value, value, value); return output; } )"; @@ -137,7 +138,7 @@ struct InputData_SimpleModule struct OutputData_SimpleModule { - float value; + vec4 value; }; // Main module @@ -151,20 +152,21 @@ layout(std140) uniform _nzslBindingExtData_block } ExtData_block; /**************** Inputs ****************/ -in float _nzslInvalue; +in float _nzslVarying0; // _nzslInvalue /*************** Outputs ***************/ -out float _nzslOutvalue; +layout(location = 0) out vec4 _nzslOutvalue; void main() { InputData_SimpleModule input_; - input_.value = _nzslInvalue; + input_.value = _nzslVarying0; Data_SimpleModule data; data.value = ExtData_block.data.value; + float value = ((GetDataValue_SimpleModule(data)) * input_.value) * 3.141592; OutputData_SimpleModule output_; - output_.value = ((GetDataValue_SimpleModule(data)) * input_.value) * 3.141592; + output_.value = vec4(value, value, value, value); _nzslOutvalue = output_.value; return; @@ -203,12 +205,12 @@ module _SimpleModule struct InputData { - value: f32 + [location(0)] value: f32 } struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } } @@ -231,8 +233,9 @@ external ExtData fn main(input: InputData) -> OutputData { let data: _SimpleModule.Data = ExtData.block.data; + let value: f32 = ((GetDataValue(data)) * input.value) * Pi; let output: OutputData; - output.value = ((GetDataValue(data)) * input.value) * Pi; + output.value = vec4[f32](value, value, value, value); return output; } )"); @@ -251,6 +254,9 @@ OpVariable OpVariable OpVariable OpVariable +OpVariable +OpAccessChain +OpCopyMemory OpAccessChain OpLoad OpAccessChain @@ -262,11 +268,73 @@ OpAccessChain OpLoad OpFMul OpFMul +OpStore +OpLoad +OpLoad +OpLoad +OpLoad +OpCompositeConstruct OpAccessChain OpStore OpLoad +OpCompositeExtract +OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +// Author "Sir Lynix" +// Description: "Main file" +// License: "MIT" + +// Author "Lynix" +// Description: "Simple \"module\" for testing" +// License: "Public domain" + +// Module SimpleModule +// std140 layout +struct _SimpleModule_Data +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// std140 layout +struct _SimpleModule_Block +{ + data: _SimpleModule_Data +} + +fn _SimpleModule_GetDataValue(data: _SimpleModule_Data) -> f32 +{ + return data.value; +} + +struct _SimpleModule_InputData +{ + @location(0) value: f32 +} + +struct _SimpleModule_OutputData +{ + @location(0) value: vec4 +} + +@group(0) @binding(0) var ExtData_block: _SimpleModule_Block; + +@fragment +fn main(input: _SimpleModule_InputData) -> _SimpleModule_OutputData +{ + var data: _SimpleModule_Data; + data.value = ExtData_block.data.value; + var value: f32 = ((_SimpleModule_GetDataValue(data)) * input.value) * 3.141592; + var output: _SimpleModule_OutputData; + output.value = vec4(value, value, value, value); + return output; +} +)"); } WHEN("Importing nested modules") @@ -308,13 +376,13 @@ module Modules.InputOutput; [export] struct InputData { - value: f32 + [location(0)] value: f32 } [export] struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } [export] @@ -340,8 +408,9 @@ external [entry(frag)] fn main(input: Input) -> OutputData { + let value = block.data.value * input.value; let output: OutputDataAlias; - output.value = block.data.value * input.value; + output.value = vec4[f32](value, value, value, value); return output; } )"; @@ -382,7 +451,7 @@ struct InputData_Modules_InputOutput struct OutputData_Modules_InputOutput { - float value; + vec4 value; }; // Main module @@ -393,18 +462,19 @@ layout(std140) uniform _nzslBindingblock } block; /**************** Inputs ****************/ -in float _nzslInvalue; +in float _nzslVarying0; // _nzslInvalue /*************** Outputs ***************/ -out float _nzslOutvalue; +layout(location = 0) out vec4 _nzslOutvalue; void main() { InputData_Modules_InputOutput input_; - input_.value = _nzslInvalue; + input_.value = _nzslVarying0; + float value = block.data.value * input_.value; OutputData_Modules_InputOutput output_; - output_.value = block.data.value * input_.value; + output_.value = vec4(value, value, value, value); _nzslOutvalue = output_.value; return; @@ -442,12 +512,12 @@ module _Modules_InputOutput { struct InputData { - value: f32 + [location(0)] value: f32 } struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } } @@ -467,8 +537,9 @@ external [entry(frag)] fn main(input: Input) -> OutputData { + let value: f32 = block.data.value * input.value; let output: OutputDataAlias; - output.value = block.data.value * input.value; + output.value = vec4[f32](value, value, value, value); return output; } )"); @@ -478,16 +549,67 @@ OpFunction OpLabel OpVariable OpVariable +OpVariable +OpAccessChain +OpCopyMemory OpAccessChain OpLoad OpAccessChain OpLoad OpFMul +OpStore +OpLoad +OpLoad +OpLoad +OpLoad +OpCompositeConstruct OpAccessChain OpStore OpLoad +OpCompositeExtract +OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +// Module Modules.Data +// std140 layout +struct _Modules_Data_Data +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} +// Module Modules.Block + +// std140 layout +struct _Modules_Block_Block +{ + data: _Modules_Data_Data +} +// Module Modules.InputOutput +struct _Modules_InputOutput_InputData +{ + @location(0) value: f32 +} + +struct _Modules_InputOutput_OutputData +{ + @location(0) value: vec4 +} + +@group(0) @binding(0) var block: _Modules_Block_Block; + +@fragment +fn main(input: _Modules_InputOutput_InputData) -> _Modules_InputOutput_OutputData +{ + var value: f32 = block.data.value * input.value; + var output: _Modules_InputOutput_OutputData; + output.value = vec4(value, value, value, value); + return output; +} +)"); } WHEN("Testing AST variable indices remapping") @@ -832,6 +954,70 @@ OpFunctionCall OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +// Module Modules.Data +// std140 layout +struct _Modules_Data_Light +{ + color: vec4, + intensities: vec2, + _padding0: f32, + _padding1: f32 +} + +// std140 layout +struct _Modules_Data_Lights +{ + lights: array<_Modules_Data_Light, 3> +} +// Module Modules.Func + +fn _Modules_Func_SumLightColor(lightData: _Modules_Data_Lights) -> vec4 +{ + var color: vec4 = vec4(0.0, 0.0, 0.0, 0.0); + { + var index: u32 = 0u; + var _nzsl_to: u32 = 3; + while (index < _nzsl_to) + { + color += lightData.lights[index].color; + index += 1u; + } + + } + + return color; +} + +fn _Modules_Func_SumLightIntensities(lightData: _Modules_Data_Lights) -> vec2 +{ + var intensities: vec2 = vec2(0, 0); + { + var _nzsl_counter: u32 = 0u; + while (_nzsl_counter < 3u) + { + var light: _Modules_Data_Light = lightData.lights[_nzsl_counter]; + intensities += light.intensities; + _nzsl_counter += 1u; + } + + } + + return intensities; +} + +@group(0) @binding(0) var lightData: _Modules_Data_Lights; + +@fragment +fn main() +{ + var data: _Modules_Data_Lights; + data.lights = lightData.lights; + var color: vec4 = _Modules_Func_SumLightColor(data); + var intensities: vec2 = _Modules_Func_SumLightIntensities(data); +} +)"); } WHEN("Testing forward vs deferred based on option") @@ -1156,13 +1342,13 @@ struct Unused {} [export] struct InputData { - value: f32 + [location(0)] value: f32 } [export] struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } )"; @@ -1184,9 +1370,10 @@ external ExtData fn main(input: Module.InputData) -> Module.OutputData { let data = ExtData.block.data; + let value = Module.GetDataValue(data) * input.value * Module.Pi; let output: Module.OutputData; - output.value = Module.GetDataValue(data) * input.value * Module.Pi; + output.value = vec4[f32](value, value, value, value); return output; } )"; @@ -1229,7 +1416,7 @@ struct InputData_Simple_Module struct OutputData_Simple_Module { - float value; + vec4 value; }; // Main module @@ -1243,20 +1430,21 @@ layout(std140) uniform _nzslBindingExtData_block } ExtData_block; /**************** Inputs ****************/ -in float _nzslInvalue; +in float _nzslVarying0; // _nzslInvalue /*************** Outputs ***************/ -out float _nzslOutvalue; +layout(location = 0) out vec4 _nzslOutvalue; void main() { InputData_Simple_Module input_; - input_.value = _nzslInvalue; + input_.value = _nzslVarying0; Data_Simple_Module data; data.value = ExtData_block.data.value; + float value = ((GetDataValue_Simple_Module(data)) * input_.value) * 3.141592; OutputData_Simple_Module output_; - output_.value = ((GetDataValue_Simple_Module(data)) * input_.value) * 3.141592; + output_.value = vec4(value, value, value, value); _nzslOutvalue = output_.value; return; @@ -1295,12 +1483,12 @@ module _Simple_Module struct InputData { - value: f32 + [location(0)] value: f32 } struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } } @@ -1315,8 +1503,9 @@ external ExtData fn main(input: Module.InputData) -> Module.OutputData { let data: Module.Data = ExtData.block.data; + let value: f32 = ((Module.GetDataValue(data)) * input.value) * Module.Pi; let output: Module.OutputData; - output.value = ((Module.GetDataValue(data)) * input.value) * Module.Pi; + output.value = vec4[f32](value, value, value, value); return output; } )"); @@ -1335,6 +1524,9 @@ OpVariable OpVariable OpVariable OpVariable +OpVariable +OpAccessChain +OpCopyMemory OpAccessChain OpLoad OpAccessChain @@ -1346,11 +1538,73 @@ OpAccessChain OpLoad OpFMul OpFMul +OpStore +OpLoad +OpLoad +OpLoad +OpLoad +OpCompositeConstruct OpAccessChain OpStore OpLoad +OpCompositeExtract +OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +// Author "Sir Lynix" +// Description: "Main file" +// License: "MIT" + +// Author "Lynix" +// Description: "Simple \"module\" for testing" +// License: "Public domain" + +// Module Simple.Module +// std140 layout +struct _Simple_Module_Data +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// std140 layout +struct _Simple_Module_Block +{ + data: _Simple_Module_Data +} + +fn _Simple_Module_GetDataValue(data: _Simple_Module_Data) -> f32 +{ + return data.value; +} + +struct _Simple_Module_InputData +{ + @location(0) value: f32 +} + +struct _Simple_Module_OutputData +{ + @location(0) value: vec4 +} + +@group(0) @binding(0) var ExtData_block: _Simple_Module_Block; + +@fragment +fn main(input: _Simple_Module_InputData) -> _Simple_Module_OutputData +{ + var data: _Simple_Module_Data; + data.value = ExtData_block.data.value; + var value: f32 = ((_Simple_Module_GetDataValue(data)) * input.value) * 3.141592; + var output: _Simple_Module_OutputData; + output.value = vec4(value, value, value, value); + return output; +} +)"); } WHEN("Importing a simple module by name with renaming") @@ -1389,13 +1643,13 @@ struct Unused {} [export] struct InputData { - value: f32 + [location(0)] value: f32 } [export] struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } )"; @@ -1417,9 +1671,10 @@ external ExtData fn main(input: SimpleModule.InputData) -> SimpleModule.OutputData { let data = ExtData.block.data; + let value = SimpleModule.GetDataValue(data) * input.value * SimpleModule.Pi; let output: SimpleModule.OutputData; - output.value = SimpleModule.GetDataValue(data) * input.value * SimpleModule.Pi; + output.value = vec4[f32](value, value, value, value); return output; } )"; @@ -1462,7 +1717,7 @@ struct InputData_Simple_Module struct OutputData_Simple_Module { - float value; + vec4 value; }; // Main module @@ -1476,20 +1731,21 @@ layout(std140) uniform _nzslBindingExtData_block } ExtData_block; /**************** Inputs ****************/ -in float _nzslInvalue; +in float _nzslVarying0; // _nzslInvalue /*************** Outputs ***************/ -out float _nzslOutvalue; +layout(location = 0) out vec4 _nzslOutvalue; void main() { InputData_Simple_Module input_; - input_.value = _nzslInvalue; + input_.value = _nzslVarying0; Data_Simple_Module data; data.value = ExtData_block.data.value; + float value = ((GetDataValue_Simple_Module(data)) * input_.value) * 3.141592; OutputData_Simple_Module output_; - output_.value = ((GetDataValue_Simple_Module(data)) * input_.value) * 3.141592; + output_.value = vec4(value, value, value, value); _nzslOutvalue = output_.value; return; @@ -1528,12 +1784,12 @@ module _Simple_Module struct InputData { - value: f32 + [location(0)] value: f32 } struct OutputData { - value: f32 + [location(0)] value: vec4[f32] } } @@ -1548,8 +1804,9 @@ external ExtData fn main(input: SimpleModule.InputData) -> SimpleModule.OutputData { let data: SimpleModule.Data = ExtData.block.data; + let value: f32 = ((SimpleModule.GetDataValue(data)) * input.value) * SimpleModule.Pi; let output: SimpleModule.OutputData; - output.value = ((SimpleModule.GetDataValue(data)) * input.value) * SimpleModule.Pi; + output.value = vec4[f32](value, value, value, value); return output; } )"); @@ -1568,6 +1825,9 @@ OpVariable OpVariable OpVariable OpVariable +OpVariable +OpAccessChain +OpCopyMemory OpAccessChain OpLoad OpAccessChain @@ -1579,10 +1839,72 @@ OpAccessChain OpLoad OpFMul OpFMul +OpStore +OpLoad +OpLoad +OpLoad +OpLoad +OpCompositeConstruct OpAccessChain OpStore OpLoad +OpCompositeExtract +OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +// Author "Sir Lynix" +// Description: "Main file" +// License: "MIT" + +// Author "Lynix" +// Description: "Simple \"module\" for testing" +// License: "Public domain" + +// Module Simple.Module +// std140 layout +struct _Simple_Module_Data +{ + value: f32, + _padding0: f32, + _padding1: f32, + _padding2: f32 +} + +// std140 layout +struct _Simple_Module_Block +{ + data: _Simple_Module_Data +} + +fn _Simple_Module_GetDataValue(data: _Simple_Module_Data) -> f32 +{ + return data.value; +} + +struct _Simple_Module_InputData +{ + @location(0) value: f32 +} + +struct _Simple_Module_OutputData +{ + @location(0) value: vec4 +} + +@group(0) @binding(0) var ExtData_block: _Simple_Module_Block; + +@fragment +fn main(input: _Simple_Module_InputData) -> _Simple_Module_OutputData +{ + var data: _Simple_Module_Data; + data.value = ExtData_block.data.value; + var value: f32 = ((_Simple_Module_GetDataValue(data)) * input.value) * 3.141592; + var output: _Simple_Module_OutputData; + output.value = vec4(value, value, value, value); + return output; +} +)"); } } diff --git a/tests/src/Tests/ShaderUtils.cpp b/tests/src/Tests/ShaderUtils.cpp index 9f531eee..3ba6b0a9 100644 --- a/tests/src/Tests/ShaderUtils.cpp +++ b/tests/src/Tests/ShaderUtils.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -13,7 +14,8 @@ #include #include #include -#include +#include +#include namespace NAZARA_ANONYMOUS_NAMESPACE { @@ -462,6 +464,56 @@ void ExpectSPIRV(nzsl::Ast::Module& shaderModule, std::string_view expectedOutpu } } +void ExpectWGSL(const nzsl::Ast::Module& shader, std::string_view expectedOutput, const nzsl::BackendParameters& options, const nzsl::WgslWriter::Environment& env) +{ + NAZARA_USE_ANONYMOUS_NAMESPACE + + // Clone to avoid cross-test changes + nzsl::Ast::ModulePtr moduleClone = nzsl::Ast::Clone(shader); + + std::string source = SanitizeSource(expectedOutput); + + SECTION("Generating WGSL") + { + nzsl::Ast::ModulePtr sanitizedModule; + WHEN("Sanitizing a second time") + { + nzsl::Ast::TransformerContext context; + nzsl::Ast::ResolveTransformer resolver; + REQUIRE_NOTHROW(resolver.Transform(*moduleClone, context)); + } + nzsl::Ast::Module& targetModule = (sanitizedModule) ? *sanitizedModule : *moduleClone; + + nzsl::WgslWriter writer; + writer.SetEnv(env); + nzsl::WgslWriter::Output output = writer.Generate(targetModule, options); + + SECTION("Validating expected code") + { + std::string outputCode = SanitizeSource(output.code); + if (outputCode.find(source) == std::string::npos) + HandleSourceError("WGSL", source, outputCode); + } + + SECTION("Validating full WGSL code (using wgsl-validator)") + { + char* error = nullptr; + wgsl_validator_t* validator = wgsl_validator_create(); + Nz::CallOnExit cleanupOnExit([&] + { + if (error != nullptr) + wgsl_validator_free_error(error); + wgsl_validator_destroy(validator); + }); + if (wgsl_validator_validate(validator, output.code.c_str(), &error)) + { + INFO("full WGSL output:\n" << output.code << "\nerror:\n" << error); + REQUIRE(false); + } + } + } +} + std::filesystem::path GetResourceDir() { static std::filesystem::path resourceDir = [] diff --git a/tests/src/Tests/ShaderUtils.hpp b/tests/src/Tests/ShaderUtils.hpp index 95d15537..0c47b376 100644 --- a/tests/src/Tests/ShaderUtils.hpp +++ b/tests/src/Tests/ShaderUtils.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -34,6 +35,7 @@ void ExpectGLSL(nzsl::ShaderStageType stageType, nzsl::Ast::Module& shader, std: void ExpectGLSL(nzsl::Ast::Module& shader, std::string_view expectedOutput, const nzsl::BackendParameters& options = {}, const nzsl::GlslWriter::Environment& env = {}, bool testShaderCompilation = true); void ExpectNZSL(const nzsl::Ast::Module& shader, std::string_view expectedOutput); void ExpectSPIRV(nzsl::Ast::Module& shader, std::string_view expectedOutput, const nzsl::BackendParameters& options = {}, const nzsl::SpirvWriter::Environment& env = {}, bool outputParameter = false, const spvtools::ValidatorOptions& validatorOptions = {}); +void ExpectWGSL(const nzsl::Ast::Module& shader, std::string_view expectedOutput, const nzsl::BackendParameters& options = {}, const nzsl::WgslWriter::Environment& env = {}); std::filesystem::path GetResourceDir(); diff --git a/tests/src/Tests/SwizzleTests.cpp b/tests/src/Tests/SwizzleTests.cpp index 5050c786..896936d1 100644 --- a/tests/src/Tests/SwizzleTests.cpp +++ b/tests/src/Tests/SwizzleTests.cpp @@ -54,6 +54,15 @@ OpVectorShuffle OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var vec: vec4 = vec4(0.0, 1.0, 2.0, 3.0); + var value: vec3 = vec.xyz; +} +)"); } WHEN("writing") @@ -102,6 +111,16 @@ OpVectorShuffle OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var vec: vec4 = vec4(0.0, 0.0, 0.0, 0.0); + var _nzsl_cachedResult: vec3 = vec3(1.0, 2.0, 3.0); + vec = vec4(vec.x, _nzsl_cachedResult); +} +)"); } } @@ -157,6 +176,16 @@ OpStore OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var value: i32 = 42; + var vec: vec3 = vec3(value, value, value); + var vec_2: vec3 = vec3(47.0, 47.0, 47.0); +} +)"); } GIVEN("a function value") @@ -208,6 +237,17 @@ OpCompositeConstruct OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var _nzsl_cachedResult: f32 = max(2.0, 1.0); + var v: vec3 = vec3(_nzsl_cachedResult, _nzsl_cachedResult, _nzsl_cachedResult); + var _nzsl_cachedResult_2: f32 = min(2.0, 1.0); + var v2: vec3 = vec3(_nzsl_cachedResult_2, _nzsl_cachedResult_2, _nzsl_cachedResult_2); +} +)"); } } @@ -262,6 +302,15 @@ OpCompositeConstruct OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var vec: vec4 = vec4(0.0, 1.0, 2.0, 3.0); + var value: vec4 = vec4(vec.xyz.yz.y, vec.xyz.yz.y, vec.xyz.yz.y, vec.xyz.yz.y); +} +)"); } WHEN("writing") @@ -315,6 +364,18 @@ OpVectorShuffle OpStore OpReturn OpFunctionEnd)"); + + ExpectWGSL(*shaderModule, R"( +@fragment +fn main() +{ + var vec: vec4 = vec4(0.0, 1.0, 2.0, 3.0); + var _nzsl_cachedResult: f32 = 0.0; + vec = vec4(vec.x, vec.y, vec.z, _nzsl_cachedResult); + var _nzsl_cachedResult_2: vec2 = vec2(1.0, 0.0); + vec = vec4(vec.x, vec.y, _nzsl_cachedResult_2); +} +)"); } } } diff --git a/tests/xmake.lua b/tests/xmake.lua index e9e6897a..a015fc43 100644 --- a/tests/xmake.lua +++ b/tests/xmake.lua @@ -6,7 +6,7 @@ if has_config("tests") then add_defines("CATCH_CONFIG_NO_POSIX_SIGNALS") end - add_requires("catch2 3", "spirv-tools", "tiny-process-library") + add_requires("catch2 3", "wgsl-validator", "spirv-tools", "tiny-process-library") add_requires("glslang", { configs = { rtti = has_config("ubsan") } }) -- ubsan requires rtti add_includedirs("src") @@ -19,7 +19,7 @@ if has_config("tests") then add_files("src/**.cpp") add_deps("nzsl") - add_packages("catch2", "glslang", "spirv-tools") + add_packages("catch2", "glslang", "wgsl-validator", "spirv-tools") if has_config("with_nzslc") then add_deps("nzslc", { links = {} }) diff --git a/xmake.lua b/xmake.lua index 0e963221..e08fe0a2 100644 --- a/xmake.lua +++ b/xmake.lua @@ -28,7 +28,7 @@ end ----------------------- Dependencies ----------------------- add_repositories("nazara-engine-repo https://github.com/NazaraEngine/xmake-repo") -add_requires("fmt", { system = false }) +add_requires("fmt 12.0.0", { system = false }) add_requires("nazarautils", "fast_float", "frozen", "lz4 >=1.9", "ordered_map") if has_config("fs_watcher") then