diff --git a/build.sbt b/build.sbt index 2fef2351..8005cccd 100644 --- a/build.sbt +++ b/build.sbt @@ -36,6 +36,7 @@ lazy val vulkanNatives = else Seq.empty lazy val commonSettings = Seq( + scalacOptions ++= Seq("-feature", "-deprecation", "-unchecked", "-language:implicitConversions"), libraryDependencies ++= Seq( "dev.zio" % "izumi-reflect_3" % "2.3.10", "com.lihaoyi" % "pprint_3" % "0.9.0", diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala index b4a6a302..2886e837 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala @@ -1,12 +1,8 @@ package io.computenode.cyfra.spirv -import io.computenode.cyfra.dsl.Expression.{E, FunctionCall} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Expression.E import scala.collection.mutable -import scala.quoted.Expr private[cyfra] object BlockBuilder: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala index a8fb18d2..974f045f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.spirv import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala index 38809183..1f8c4cb6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala @@ -2,33 +2,30 @@ package io.computenode.cyfra.spirv import java.nio.charset.StandardCharsets -private[cyfra] object Opcodes { +private[cyfra] object Opcodes: def intToBytes(i: Int): List[Byte] = List[Byte]((i >>> 24).asInstanceOf[Byte], (i >>> 16).asInstanceOf[Byte], (i >>> 8).asInstanceOf[Byte], (i >>> 0).asInstanceOf[Byte]) - private[cyfra] trait Words { + private[cyfra] trait Words: def toWords: List[Byte] def length: Int - } - private[cyfra] case class Word(bytes: Array[Byte]) extends Words { + private[cyfra] case class Word(bytes: Array[Byte]) extends Words: def toWords: List[Byte] = bytes.toList def length = 1 override def toString = s"Word(${bytes.mkString(", ")}${if bytes.length == 4 then s" [i = ${BigInt(bytes).toInt}])" else ""}" - } - private[cyfra] case class WordVariable(name: String) extends Words { + private[cyfra] case class WordVariable(name: String) extends Words: def toWords: List[Byte] = List(-1, -1, -1, -1) def length = 1 - } - private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words { + private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words: override def toWords: List[Byte] = code.toWords.take(2) ::: intToBytes(length).reverse.take(2) ::: operands.flatMap(_.toWords) @@ -41,38 +38,32 @@ private[cyfra] object Opcodes { }) override def toString: String = s"${code.mnemo} ${operands.mkString(", ")}" - } - private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words { + private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words: override def toWords: List[Byte] = intToBytes(opcode).reverse override def length: Int = 1 - } - private[cyfra] case class Text(text: String) extends Words { - override def toWords: List[Byte] = { + private[cyfra] case class Text(text: String) extends Words: + override def toWords: List[Byte] = val textBytes = text.getBytes(StandardCharsets.UTF_8).toList val complBytesLength = 4 - (textBytes.length % 4) val complBytes = List.fill[Byte](complBytesLength)(0) textBytes ::: complBytes - } override def length: Int = toWords.length / 4 - } - private[cyfra] case class IntWord(i: Int) extends Words { + private[cyfra] case class IntWord(i: Int) extends Words: override def toWords: List[Byte] = intToBytes(i).reverse override def length: Int = 1 - } - private[cyfra] case class ResultRef(result: Int) extends Words { + private[cyfra] case class ResultRef(result: Int) extends Words: override def toWords: List[Byte] = intToBytes(result).reverse override def length: Int = 1 override def toString: String = s"%$result" - } val MagicNumber = Code("MagicNumber", 0x07230203) val Version = Code("Version", 0x00010000) @@ -81,16 +72,15 @@ private[cyfra] object Opcodes { val OpCodeMask = Code("OpCodeMask", 0xffff) val WordCountShift = Code("WordCountShift", 16) - object SourceLanguage { + object SourceLanguage: val Unknown = Code("Unknown", 0) val ESSL = Code("ESSL", 1) val GLSL = Code("GLSL", 2) val OpenCL_C = Code("OpenCL_C", 3) val OpenCL_CPP = Code("OpenCL_CPP", 4) val HLSL = Code("HLSL", 5) - } - object ExecutionModel { + object ExecutionModel: val Vertex = Code("Vertex", 0) val TessellationControl = Code("TessellationControl", 1) val TessellationEvaluation = Code("TessellationEvaluation", 2) @@ -98,21 +88,18 @@ private[cyfra] object Opcodes { val Fragment = Code("Fragment", 4) val GLCompute = Code("GLCompute", 5) val Kernel = Code("Kernel", 6) - } - object AddressingModel { + object AddressingModel: val Logical = Code("Logical", 0) val Physical32 = Code("Physical32", 1) val Physical64 = Code("Physical64", 2) - } - object MemoryModel { + object MemoryModel: val Simple = Code("Simple", 0) val GLSL450 = Code("GLSL450", 1) val OpenCL = Code("OpenCL", 2) - } - object ExecutionMode { + object ExecutionMode: val Invocations = Code("Invocations", 0) val SpacingEqual = Code("SpacingEqual", 1) val SpacingFractionalEven = Code("SpacingFractionalEven", 2) @@ -150,9 +137,8 @@ private[cyfra] object Opcodes { val SubgroupsPerWorkgroup = Code("SubgroupsPerWorkgroup", 36) val PostDepthCoverage = Code("PostDepthCoverage", 4446) val StencilRefReplacingEXT = Code("StencilRefReplacingEXT", 5027) - } - object StorageClass { + object StorageClass: val UniformConstant = Code("UniformConstant", 0) val Input = Code("Input", 1) val Uniform = Code("Uniform", 2) @@ -166,9 +152,8 @@ private[cyfra] object Opcodes { val AtomicCounter = Code("AtomicCounter", 10) val Image = Code("Image", 11) val StorageBuffer = Code("StorageBuffer", 12) - } - object Dim { + object Dim: val Dim1D = Code("Dim1D", 0) val Dim2D = Code("Dim2D", 1) val Dim3D = Code("Dim3D", 2) @@ -176,22 +161,19 @@ private[cyfra] object Opcodes { val Rect = Code("Rect", 4) val Buffer = Code("Buffer", 5) val SubpassData = Code("SubpassData", 6) - } - object SamplerAddressingMode { + object SamplerAddressingMode: val None = Code("None", 0) val ClampToEdge = Code("ClampToEdge", 1) val Clamp = Code("Clamp", 2) val Repeat = Code("Repeat", 3) val RepeatMirrored = Code("RepeatMirrored", 4) - } - object SamplerFilterMode { + object SamplerFilterMode: val Nearest = Code("Nearest", 0) val Linear = Code("Linear", 1) - } - object ImageFormat { + object ImageFormat: val Unknown = Code("Unknown", 0) val Rgba32f = Code("Rgba32f", 1) val Rgba16f = Code("Rgba16f", 2) @@ -232,9 +214,8 @@ private[cyfra] object Opcodes { val Rg8ui = Code("Rg8ui", 37) val R16ui = Code("R16ui", 38) val R8ui = Code("R8ui", 39) - } - object ImageChannelOrder { + object ImageChannelOrder: val R = Code("R", 0) val A = Code("A", 1) val RG = Code("RG", 2) @@ -255,9 +236,8 @@ private[cyfra] object Opcodes { val sRGBA = Code("sRGBA", 17) val sBGRA = Code("sBGRA", 18) val ABGR = Code("ABGR", 19) - } - object ImageChannelDataType { + object ImageChannelDataType: val SnormInt8 = Code("SnormInt8", 0) val SnormInt16 = Code("SnormInt16", 1) val UnormInt8 = Code("UnormInt8", 2) @@ -275,9 +255,8 @@ private[cyfra] object Opcodes { val Float = Code("Float", 14) val UnormInt24 = Code("UnormInt24", 15) val UnormInt101010_2 = Code("UnormInt101010_2", 16) - } - object ImageOperandsShift { + object ImageOperandsShift: val Bias = Code("Bias", 0) val Lod = Code("Lod", 1) val Grad = Code("Grad", 2) @@ -286,9 +265,8 @@ private[cyfra] object Opcodes { val ConstOffsets = Code("ConstOffsets", 5) val Sample = Code("Sample", 6) val MinLod = Code("MinLod", 7) - } - object ImageOperandsMask { + object ImageOperandsMask: val MaskNone = Code("MaskNone", 0) val Bias = Code("Bias", 0x00000001) val Lod = Code("Lod", 0x00000002) @@ -298,44 +276,38 @@ private[cyfra] object Opcodes { val ConstOffsets = Code("ConstOffsets", 0x00000020) val Sample = Code("Sample", 0x00000040) val MinLod = Code("MinLod", 0x00000080) - } - object FPFastMathModeShift { + object FPFastMathModeShift: val NotNaN = Code("NotNaN", 0) val NotInf = Code("NotInf", 1) val NSZ = Code("NSZ", 2) val AllowRecip = Code("AllowRecip", 3) val Fast = Code("Fast", 4) - } - object FPFastMathModeMask { + object FPFastMathModeMask: val MaskNone = Code("MaskNone", 0) val NotNaN = Code("NotNaN", 0x00000001) val NotInf = Code("NotInf", 0x00000002) val NSZ = Code("NSZ", 0x00000004) val AllowRecip = Code("AllowRecip", 0x00000008) val Fast = Code("Fast", 0x00000010) - } - object FPRoundingMode { + object FPRoundingMode: val RTE = Code("RTE", 0) val RTZ = Code("RTZ", 1) val RTP = Code("RTP", 2) val RTN = Code("RTN", 3) - } - object LinkageType { + object LinkageType: val Export = Code("Export", 0) val Import = Code("Import", 1) - } - object AccessQualifier { + object AccessQualifier: val ReadOnly = Code("ReadOnly", 0) val WriteOnly = Code("WriteOnly", 1) val ReadWrite = Code("ReadWrite", 2) - } - object FunctionParameterAttribute { + object FunctionParameterAttribute: val Zext = Code("Zext", 0) val Sext = Code("Sext", 1) val ByVal = Code("ByVal", 2) @@ -344,9 +316,8 @@ private[cyfra] object Opcodes { val NoCapture = Code("NoCapture", 5) val NoWrite = Code("NoWrite", 6) val NoReadWrite = Code("NoReadWrite", 7) - } - object Decoration { + object Decoration: val RelaxedPrecision = Code("RelaxedPrecision", 0) val SpecId = Code("SpecId", 1) val Block = Code("Block", 2) @@ -396,9 +367,8 @@ private[cyfra] object Opcodes { val PassthroughNV = Code("PassthroughNV", 5250) val ViewportRelativeNV = Code("ViewportRelativeNV", 5252) val SecondaryViewportRelativeNV = Code("SecondaryViewportRelativeNV", 5256) - } - object BuiltIn { + object BuiltIn: val Position = Code("Position", 0) val PointSize = Code("PointSize", 1) val ClipDistance = Code("ClipDistance", 3) @@ -463,50 +433,43 @@ private[cyfra] object Opcodes { val SecondaryViewportMaskNV = Code("SecondaryViewportMaskNV", 5258) val PositionPerViewNV = Code("PositionPerViewNV", 5261) val ViewportMaskPerViewNV = Code("ViewportMaskPerViewNV", 5262) - } - object SelectionControlShift { + object SelectionControlShift: val Flatten = Code("Flatten", 0) val DontFlatten = Code("DontFlatten", 1) - } - object SelectionControlMask { + object SelectionControlMask: val MaskNone = Code("MaskNone", 0) val Flatten = Code("Flatten", 0x00000001) val DontFlatten = Code("DontFlatten", 0x00000002) - } - object LoopControlShift { + object LoopControlShift: val Unroll = Code("Unroll", 0) val DontUnroll = Code("DontUnroll", 1) val DependencyInfinite = Code("DependencyInfinite", 2) val DependencyLength = Code("DependencyLength", 3) - } - object LoopControlMask { + object LoopControlMask: val MaskNone = Code("MaskNone", 0) val Unroll = Code("Unroll", 0x00000001) val DontUnroll = Code("DontUnroll", 0x00000002) val DependencyInfinite = Code("DependencyInfinite", 0x00000004) val DependencyLength = Code("DependencyLength", 0x00000008) - } - object FunctionControlShift { + object FunctionControlShift: val Inline = Code("Inline", 0) val DontInline = Code("DontInline", 1) val Pure = Code("Pure", 2) val Const = Code("Const", 3) - } - object FunctionControlMask { + object FunctionControlMask: val MaskNone = Code("MaskNone", 0) val Inline = Code("Inline", 0x00000001) val DontInline = Code("DontInline", 0x00000002) val Pure = Code("Pure", 0x00000004) val Const = Code("Const", 0x00000008) - } - object MemorySemanticsShift { + object MemorySemanticsShift: val Acquire = Code("Acquire", 1) val Release = Code("Release", 2) val AcquireRelease = Code("AcquireRelease", 3) @@ -517,9 +480,8 @@ private[cyfra] object Opcodes { val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 9) val AtomicCounterMemory = Code("AtomicCounterMemory", 10) val ImageMemory = Code("ImageMemory", 11) - } - object MemorySemanticsMask { + object MemorySemanticsMask: val MaskNone = Code("MaskNone", 0) val Acquire = Code("Acquire", 0x00000002) val Release = Code("Release", 0x00000004) @@ -531,51 +493,43 @@ private[cyfra] object Opcodes { val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 0x00000200) val AtomicCounterMemory = Code("AtomicCounterMemory", 0x00000400) val ImageMemory = Code("ImageMemory", 0x00000800) - } - object MemoryAccessShift { + object MemoryAccessShift: val Volatile = Code("Volatile", 0) val Aligned = Code("Aligned", 1) val Nontemporal = Code("Nontemporal", 2) - } - object MemoryAccessMask { + object MemoryAccessMask: val MaskNone = Code("MaskNone", 0) val Volatile = Code("Volatile", 0x00000001) val Aligned = Code("Aligned", 0x00000002) val Nontemporal = Code("Nontemporal", 0x00000004) - } - object Scope { + object Scope: val CrossDevice = Code("CrossDevice", 0) val Device = Code("Device", 1) val Workgroup = Code("Workgroup", 2) val Subgroup = Code("Subgroup", 3) val Invocation = Code("Invocation", 4) - } - object GroupOperation { + object GroupOperation: val Reduce = Code("Reduce", 0) val InclusiveScan = Code("InclusiveScan", 1) val ExclusiveScan = Code("ExclusiveScan", 2) - } - object KernelEnqueueFlags { + object KernelEnqueueFlags: val NoWait = Code("NoWait", 0) val WaitKernel = Code("WaitKernel", 1) val WaitWorkGroup = Code("WaitWorkGroup", 2) - } - object KernelProfilingInfoShift { + object KernelProfilingInfoShift: val CmdExecTime = Code("CmdExecTime", 0) - } - object KernelProfilingInfoMask { + object KernelProfilingInfoMask: val MaskNone = Code("MaskNone", 0) val CmdExecTime = Code("CmdExecTime", 0x00000001) - } - object Capability { + object Capability: val Matrix = Code("Matrix", 0) val Shader = Code("Shader", 1) val Geometry = Code("Geometry", 2) @@ -664,9 +618,8 @@ private[cyfra] object Opcodes { val SubgroupShuffleINTEL = Code("SubgroupShuffleINTEL", 5568) val SubgroupBufferBlockIOINTEL = Code("SubgroupBufferBlockIOINTEL", 5569) val SubgroupImageBlockIOINTEL = Code("SubgroupImageBlockIOINTEL", 5570) - } - object Op { + object Op: val OpNop = Code("OpNop", 0) val OpUndef = Code("OpUndef", 1) val OpSourceContinued = Code("OpSourceContinued", 2) @@ -995,9 +948,8 @@ private[cyfra] object Opcodes { val OpSubgroupBlockWriteINTEL = Code("OpSubgroupBlockWriteINTEL", 5576) val OpSubgroupImageBlockReadINTEL = Code("OpSubgroupImageBlockReadINTEL", 5577) val OpSubgroupImageBlockWriteINTEL = Code("OpSubgroupImageBlockWriteINTEL", 5578) - } - object GlslOp { + object GlslOp: val Round = Code("Round", 1) val RoundEven = Code("RoundEven", 2) val Trunc = Code("Trunc", 3) @@ -1078,6 +1030,3 @@ private[cyfra] object Opcodes { val NMin = Code("NMin", 79) val NMax = Code("NMax", 80) val NClamp = Code("NClamp", 81) - } - -} diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala index 122f6505..9fe1b386 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala @@ -2,7 +2,6 @@ package io.computenode.cyfra.spirv import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.spirv.Context.initialContext import io.computenode.cyfra.spirv.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.{LTag, LightTypeTag} @@ -37,12 +36,11 @@ private[cyfra] object SpirvTypes: type Vec3C[T <: Value] = Vec3[T] type Vec4C[T <: Value] = Vec4[T] - def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match { + def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) - } def vecSize(tag: LightTypeTag): Int = tag match case v if v <:< LVec2Tag => 2 @@ -59,19 +57,17 @@ private[cyfra] object SpirvTypes: def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) - def toWord(tpe: Tag[?], value: Any): Words = tpe match { + def toWord(tpe: Tag[?], value: Any): Words = tpe match case t if t == Int32Tag => IntWord(value.asInstanceOf[Int]) case t if t == UInt32Tag => IntWord(value.asInstanceOf[Int]) case t if t == Float32Tag => - val fl = value match { + val fl = value match case fl: Float => fl case dl: Double => dl.toFloat case il: Int => il.toFloat - } Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) - } def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) @@ -99,9 +95,9 @@ private[cyfra] object SpirvTypes: ctx.copy( valueTypeMap = ctx.valueTypeMap ++ Map( valType.tag -> typeDefIndex, - (summon[LTag[Vec2C]].tag.combine(valType.tag)) -> (typeDefIndex + 4), - (summon[LTag[Vec3C]].tag.combine(valType.tag)) -> (typeDefIndex + 5), - (summon[LTag[Vec4C]].tag.combine(valType.tag)) -> (typeDefIndex + 11), + summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4), + summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5), + summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11), ), funPointerTypeMap = ctx.funPointerTypeMap ++ Map( typeDefIndex -> (typeDefIndex + 1), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala index 4d58ae7f..07ae9aab 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala @@ -1,29 +1,26 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.* -import io.computenode.cyfra.spirv.Opcodes.* -import izumi.reflect.Tag -import izumi.reflect.macrortti.{LTag, LTagK, LightTypeTag} -import org.lwjgl.BufferUtils -import SpirvProgramCompiler.* -import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.Value.Scalar import io.computenode.cyfra.dsl.struct.GStruct.* import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctions import io.computenode.cyfra.spirv.compilers.GStructCompiler.* -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.{compileFunctions, defineFunctionTypes} +import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.* +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag +import org.lwjgl.BufferUtils import java.nio.ByteBuffer import scala.annotation.tailrec import scala.collection.mutable -import scala.math.random import scala.runtime.stdLibPatches.Predef.summon -import scala.util.Random private[cyfra] object DSLCompiler: @@ -80,10 +77,9 @@ private[cyfra] object DSLCompiler: decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: constDefs ::: varDefs ::: main ::: fnDefs - val fullCode = code.map { + val fullCode = code.map: case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) case x => x - } val bytes = fullCode.flatMap(_.toWords).toArray BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala index 8ef41c8d..1a8cd62b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala @@ -1,24 +1,22 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.spirv.Opcodes.* -import ExtFunctionCompiler.compileExtFunctionCall -import FunctionCompiler.compileFunctionCall -import WhenCompiler.compileWhen -import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GArray.GArrayElem import io.computenode.cyfra.dsl.collections.GSeq import io.computenode.cyfra.dsl.macros.Source import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.SpirvTypes.* +import io.computenode.cyfra.spirv.compilers.ExtFunctionCompiler.compileExtFunctionCall +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctionCall +import io.computenode.cyfra.spirv.compilers.WhenCompiler.compileWhen import io.computenode.cyfra.spirv.{BlockBuilder, Context} -import io.computenode.cyfra.dsl.collections.GArray.GArrayElem import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* import scala.annotation.tailrec -import scala.collection.immutable.List as expr private[cyfra] object ExpressionCompiler: @@ -38,14 +36,13 @@ private[cyfra] object ExpressionCompiler: private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = tpe match { + val subOpcode = tpe match case i if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag || (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => binaryOpOpcode(bexpr)._1 case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => binaryOpOpcode(bexpr)._2 - } val instructions = List( Instruction( subOpcode, @@ -58,14 +55,13 @@ private[cyfra] object ExpressionCompiler: private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) = val tpe = cexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val tfOpcode = (cexpr.fromTag, cexpr) match { + val tfOpcode = (cexpr.fromTag, cexpr) match case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast - } val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) @@ -81,35 +77,34 @@ private[cyfra] object ExpressionCompiler: private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = bexpr match { + val subOpcode = bexpr match case _: BitwiseAnd[?] => Op.OpBitwiseAnd case _: BitwiseOr[?] => Op.OpBitwiseOr case _: BitwiseXor[?] => Op.OpBitwiseXor case _: BitwiseNot[?] => Op.OpNot case _: ShiftLeft[?] => Op.OpShiftLeftLogical case _: ShiftRight[?] => Op.OpShiftRightLogical - } val instructions = List( Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))), ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = { + def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = @tailrec - def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = { + def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = if exprs.isEmpty then (acc, ctx) - else { + else val expr = exprs.head if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc) - else { + else val name: Option[String] = expr.of match case Some(v) => Some(v.source.name) case _ => None - val (instructions, updatedCtx) = expr match { + val (instructions, updatedCtx) = expr match case c @ Const(x) => val constRef = ctx.constRefs((c.tag, x)) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef)) @@ -322,6 +317,7 @@ private[cyfra] object ExpressionCompiler: GSeqCompiler.compileFold(fd, ctx) case cs: ComposeStruct[?] => + // noinspection ScalaRedundantCast val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]] val fields = cs.fields val insns: List[Instruction] = List( @@ -365,12 +361,7 @@ private[cyfra] object ExpressionCompiler: (insns, updatedContext) case ph: PhantomExpression[?] => (List(), ctx) - } val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap) compileExpressions(exprs.tail, ctxWithName, acc ::: instructions) - } - } - } val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet) compileExpressions(sortedTree, ctx, Nil) - } diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala index 3d3f916f..21c04283 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala @@ -1,13 +1,12 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.library.Functions.FunctionName -import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.dsl.Expression import io.computenode.cyfra.dsl.library.Functions +import io.computenode.cyfra.dsl.library.Functions.FunctionName import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction private[cyfra] object ExtFunctionCompiler: private val fnOpMap: Map[FunctionName, Code] = Map( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala index e8910c71..3e76f60f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala @@ -1,16 +1,9 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.dsl.Expression -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.library.Functions -import io.computenode.cyfra.dsl.library.Functions.FunctionName import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.bubbleUpVars import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala index 3a44a2cf..e635c4c5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala @@ -3,11 +3,10 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.collections.GSeq import io.computenode.cyfra.dsl.collections.GSeq.* +import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.{BlockBuilder, Context} -import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* +import izumi.reflect.Tag private[cyfra] object GSeqCompiler: @@ -49,7 +48,7 @@ private[cyfra] object GSeqCompiler: def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) = val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef)) - seqExprs match { + seqExprs match case Nil => // No more transformations, so reduce ops now val resultRef = context.nextResultId val forReduceCtx = withElemRefCtx @@ -72,7 +71,7 @@ private[cyfra] object GSeqCompiler: (instructions, ctx.joinNested(reduceCtx)) case (op, dExpr) :: tail => - op match { + op match case MapOp(_) => val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) val newElemRef = mapContext.exprRefs(dExpr.treeid) @@ -105,8 +104,6 @@ private[cyfra] object GSeqCompiler: Instruction(Op.OpLabel, List(ResultRef(trueLabel))), ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult"))) - } - } val seqExprs = fold.seq.elemOps.zip(fold.seqExprs) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala index d495444d..fe3faacc 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.dsl.struct.GStructSchema.* -import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala index 32828fc2..8d16743c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala @@ -14,12 +14,11 @@ import izumi.reflect.Tag private[cyfra] object SpirvProgramCompiler: def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) = - exprs.partition { + exprs.partition: case Instruction(Op.OpVariable, _) => true case _ => false - } - def compileMain(tree: Value, resultType: Tag[?], ctx: Context): (List[Words], Context) = { + def compileMain(tree: Value, resultType: Tag[?], ctx: Context): (List[Words], Context) = val init = List( Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), @@ -49,7 +48,7 @@ private[cyfra] object SpirvProgramCompiler: List( ResultRef(codeCtx.uniformPointerMap(codeCtx.valueTypeMap(resultType.tag))), ResultRef(codeCtx.nextResultId), - ResultRef(codeCtx.outBufferBlocks(0).blockVarRef), + ResultRef(codeCtx.outBufferBlocks.head.blockVarRef), ResultRef(codeCtx.constRefs((Int32Tag, 0))), ResultRef(codeCtx.workerIndexRef), ), @@ -59,7 +58,6 @@ private[cyfra] object SpirvProgramCompiler: Instruction(Op.OpFunctionEnd, List()), ) (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) - } def getNameDecorations(ctx: Context): List[Instruction] = val funNames = ctx.functions.map { case (id, fn) => @@ -96,23 +94,21 @@ private[cyfra] object SpirvProgramCompiler: Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil - def defineVoids(context: Context): (List[Words], Context) = { + def defineVoids(context: Context): (List[Words], Context) = val voidDef = List[Words]( Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))), Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))), ) val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) (voidDef, ctxWithVoid) - } - def initAndDecorateUniforms(ins: List[Tag[?]], outs: List[Tag[?]], context: Context): (List[Words], List[Words], Context) = { + def initAndDecorateUniforms(ins: List[Tag[?]], outs: List[Tag[?]], context: Context): (List[Words], List[Words], Context) = val (inDecor, inDef, inCtx) = createAndInitBlocks(ins, in = true, context) val (outDecor, outDef, outCtx) = createAndInitBlocks(outs, in = false, inCtx) val (voidsDef, voidCtx) = defineVoids(outCtx) (inDecor ::: outDecor, voidsDef ::: inDef ::: outDef, voidCtx) - } - def createInvocationId(context: Context): (List[Words], Context) = { + def createInvocationId(context: Context): (List[Words], Context) = val definitionInstructions = List( Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), @@ -129,9 +125,8 @@ private[cyfra] object SpirvProgramCompiler: ), ) (definitionInstructions, context.copy(nextResultId = context.nextResultId + 3)) - } - def createAndInitBlocks(blocks: List[Tag[?]], in: Boolean, context: Context): (List[Words], List[Words], Context) = { + def createAndInitBlocks(blocks: List[Tag[?]], in: Boolean, context: Context): (List[Words], List[Words], Context) = val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { case ((decAcc, insnAcc, ctx), tpe) => val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, ctx.nextBinding) @@ -159,7 +154,6 @@ private[cyfra] object SpirvProgramCompiler: ) } (decoration, definition, newContext) - } def getBlockNames(context: Context, uniformSchema: GStructSchema[?]): List[Words] = def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] = @@ -215,7 +209,7 @@ private[cyfra] object SpirvProgramCompiler: ) val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) - def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = { + def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = val consts = (exprs.collect { case c @ Const(x) => (c.tag, x) @@ -234,10 +228,9 @@ private[cyfra] object SpirvProgramCompiler: withBool, newC.copy( nextResultId = newC.nextResultId + 2, - constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> (newC.nextResultId), (GBooleanTag, false) -> (newC.nextResultId + 1)), + constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)), ), ) - } def defineVarNames(ctx: Context): (List[Words], Context) = ( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala index 9a889b8f..3b3d1c13 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala @@ -1,19 +1,17 @@ package io.computenode.cyfra.spirv.compilers -import ExpressionCompiler.compileBlock -import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.control.When.WhenExpr -import io.computenode.cyfra.spirv.{BlockBuilder, Context} +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* private[cyfra] object WhenCompiler: - def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = { + def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) = - (conditions, thenCodes) match { + (conditions, thenCodes) match case (Nil, Nil) => val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx) val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid)))) @@ -42,7 +40,6 @@ private[cyfra] object WhenCompiler: ), postCtx.joinNested(elseCtx), ) - } val resultVar = ctx.nextResultId val resultLoaded = ctx.nextResultId + 1 @@ -58,4 +55,3 @@ private[cyfra] object WhenCompiler: List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) ::: caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar)))) (instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded))) - } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala index 707b637b..52b8b844 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala @@ -18,7 +18,7 @@ trait Expression[T <: Value: Tag] extends Product: .map(e => s"#${e.treeid}") .mkString("[", ", ", "]") override def toString: String = s"${this.productPrefix}(${of.fold("")(v => s"name = ${v.source}, ")}children=$childrenStrings, id=$treeid)" - private def exploreDeps(children: List[Any]): (List[Expression[?]], List[Scope[?]]) = (for (elem <- children) yield elem match { + private def exploreDeps(children: List[Any]): (List[Expression[?]], List[Scope[?]]) = (for elem <- children yield elem match { case b: Scope[?] => (None, Some(b)) case x: Expression[?] => @@ -46,10 +46,9 @@ object Expression: type E[T <: Value] = Expression[T] case class Negate[T <: Value: Tag](a: T) extends Expression[T] - sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T] { + sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T]: def a: T def b: T - } case class Sum[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] case class Diff[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] case class Mul[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T] @@ -59,10 +58,9 @@ object Expression: case class DotProd[S <: Scalar: Tag, V <: Vec[S]](a: V, b: V) extends Expression[S] sealed trait BitwiseOpExpression[T <: Scalar: Tag] extends Expression[T] - sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T] { + sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T]: def a: T def b: T - } case class BitwiseAnd[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] case class BitwiseOr[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] case class BitwiseXor[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] @@ -70,11 +68,10 @@ object Expression: case class ShiftLeft[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] case class ShiftRight[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] - sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean] { + sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean]: def operandTag = summon[Tag[T]] def a: T def b: T - } case class GreaterThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] case class LessThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] case class GreaterThanEqual[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] @@ -87,20 +84,17 @@ object Expression: case class ExtractScalar[V <: Vec[?]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S] - sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T] { + sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]: def fromTag: Tag[F] = summon[Tag[F]] def a: F - } case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32] case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32] case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32] - sealed trait Const[T <: Scalar: Tag] extends Expression[T] { + sealed trait Const[T <: Scalar: Tag] extends Expression[T]: def value: Any - } - object Const { + object Const: def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value) - } case class ConstFloat32(value: Float) extends Const[Float32] case class ConstInt32(value: Int) extends Const[Int32] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala index a1963157..090797bf 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala @@ -7,7 +7,6 @@ import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag import io.computenode.cyfra.dsl.Value.FromExpr -class GArray2D[T <: Value: Tag: FromExpr](width: Int, val arr: GArray[T]) { +class GArray2D[T <: Value: Tag: FromExpr](width: Int, val arr: GArray[T]): def at(x: Int32, y: Int32)(using Source): T = arr.at(y * width + x) -} diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala index cb380b30..103c5050 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala @@ -10,9 +10,6 @@ import io.computenode.cyfra.dsl.macros.Source import io.computenode.cyfra.dsl.{Expression, Value} import izumi.reflect.Tag -import java.util.Base64 -import scala.util.Random - class GSeq[T <: Value: Tag: FromExpr]( val uninitSource: Expression[?] => GSeqStream[?], val elemOps: List[GSeq.ElemOp[?]], @@ -64,19 +61,16 @@ object GSeq: def of[T <: Value: Tag: FromExpr](xs: List[T]) = GSeq .gen[Int32](0, _ + 1) - .map { i => - val first = when(i === 0) { - xs(0) - } + .map: i => + val first = when(i === 0): + xs.head (if xs.length == 1 then first else - xs.init.zipWithIndex.tail.foldLeft(first) { case (acc, (x, j)) => - acc.elseWhen(i === j) { - x - } - } + xs.init.zipWithIndex.tail.foldLeft(first): + case (acc, (x, j)) => + acc.elseWhen(i === j): + x ).otherwise(xs.last) - } .limit(xs.length) case class CurrentElem[T <: Value: Tag](tid: Int) extends PhantomExpression[T] with CustomTreeId: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala index 0e34626c..5b1f0013 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala @@ -10,19 +10,17 @@ import scala.annotation.targetName object Color: - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = { + def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - } // https://www.youtube.com/shorts/TH3OTy5fTog def igPallette(brightness: Vec3[Float32], contrast: Vec3[Float32], freq: Vec3[Float32], offsets: Vec3[Float32], f: Float32): Vec3[Float32] = brightness addV (contrast mulV cos(((freq * f) addV offsets) * 2f * math.Pi.toFloat)) - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = { + def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - } type InterpolationTheme = (Vec3[Float32], Vec3[Float32], Vec3[Float32]) object InterpolationThemes: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala index 0f2b3a4f..0de27564 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala @@ -2,13 +2,11 @@ package io.computenode.cyfra.dsl.library import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag -import scala.annotation.targetName - object Functions: sealed class FunctionName diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala index b8249e51..57f50add 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala @@ -1,12 +1,10 @@ package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import Functions.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} import io.computenode.cyfra.dsl.control.When.when - -import scala.concurrent.duration.DurationInt +import io.computenode.cyfra.dsl.library.Functions.* object Math3D: def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w @@ -14,22 +12,20 @@ object Math3D: def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) val cosX = -(normal dot incident) - when(n1 > n2) { + when(n1 > n2): val n = n1 / n2 val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f) { + when(sinT2 > 1f): f90 - } otherwise { + .otherwise: val cosX2 = sqrt(1.0f - sinT2) val x = 1.0f - cosX2 val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } - } otherwise { + .otherwise: val x = 1.0f - cosX val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala index f5ad184c..f84122e1 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala @@ -13,10 +13,9 @@ object FnCall: implicit inline def generate: FnCall = ${ fnCallImpl } - def fnCallImpl(using Quotes): Expr[FnCall] = { + def fnCallImpl(using Quotes): Expr[FnCall] = import quotes.reflect.* resolveFnCall - } case class FnIdentifier(shortName: String, fullName: String, args: List[LightTypeTag]) @@ -30,18 +29,17 @@ object FnCall: val name = Util.getName(ownerDef) val ddOwner = actualOwner(ownerDef) val ownerName = ddOwner.map(d => d.fullName).getOrElse("unknown") - ownerDef.tree match { + ownerDef.tree match case dd: DefDef if isPure(dd) => - val paramTerms: List[Term] = for { + val paramTerms: List[Term] = for paramGroup <- dd.paramss param <- paramGroup.params - } yield Ref(param.symbol) + yield Ref(param.symbol) val paramExprs: List[Expr[Value]] = paramTerms.map(_.asExpr.asInstanceOf[Expr[Value]]) val paramList = Expr.ofList(paramExprs) '{ FnCall(${ Expr(name) }, ${ Expr(ownerName) }, ${ paramList }) } case _ => quotes.reflect.report.errorAndAbort(s"Expected pure function. Found: $ownerDef") - } case None => quotes.reflect.report.errorAndAbort(s"Expected pure function") def isPure(using Quotes)(defdef: quotes.reflect.DefDef): Boolean = diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala index e6212397..9acf9f39 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala @@ -13,11 +13,10 @@ object Source: implicit inline def generate: Source = ${ sourceImpl } - def sourceImpl(using Quotes): Expr[Source] = { + def sourceImpl(using Quotes): Expr[Source] = import quotes.reflect.* val name = valueName '{ Source(${ name }) } - } def valueName(using Quotes): Expr[String] = import quotes.reflect.* @@ -29,14 +28,13 @@ object Source: case None => Expr("unknown") - def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = { + def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = import quotes.reflect.* var owner0 = owner while skipIf(owner0) do if owner0 == Symbol.noSymbol then return None owner0 = owner0.owner Some(owner0) - } def actualOwner(using Quotes)(owner: quotes.reflect.Symbol): Option[quotes.reflect.Symbol] = findOwner(owner, owner0 => Util.isSynthetic(owner0) || Util.getName(owner0) == "ev") diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala index c06c2198..183bbe9f 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala @@ -6,14 +6,12 @@ object Util: def isSynthetic(using Quotes)(s: quotes.reflect.Symbol) = isSyntheticAlt(s) - def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = { + def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = import quotes.reflect.* s.flags.is(Flags.Synthetic) || s.isClassConstructor || s.isLocalDummy || isScala2Macro(s) || s.name.startsWith("x$proxy") - } - def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = { + def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = import quotes.reflect.* (s.flags.is(Flags.Macro) && s.owner.flags.is(Flags.Scala2x)) || (s.flags.is(Flags.Macro) && !s.flags.is(Flags.Inline)) - } def isSyntheticName(name: String) = name == "" || (name.startsWith("")) || name == "$anonfun" || name == "macro" def getName(using Quotes)(s: quotes.reflect.Symbol) = diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticsE2eTest.scala similarity index 100% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticsE2eTest.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsE2eTest.scala similarity index 100% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsE2eTest.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructE2eTest.scala similarity index 100% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructE2eTest.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GSeqTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GseqE2eTest.scala similarity index 97% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GSeqTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GseqE2eTest.scala index 8319614a..8b70999e 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GSeqTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GseqE2eTest.scala @@ -43,8 +43,7 @@ class GseqE2eTest extends munit.FunSuite: List .iterate(n, 10)(_ + 1) .takeWhile(_ <= 200) - .filter(_ % 2 == 0) - .size + .count(_ % 2 == 0) result .zip(expected) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala index 20aea0e4..7cf9a544 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala @@ -10,21 +10,19 @@ import java.io.File import javax.imageio.ImageIO object ImageTests: - def assertImagesEquals(result: File, expected: File) = { + def assertImagesEquals(result: File, expected: File) = val expectedImage = ImageIO.read(expected) val resultImage = ImageIO.read(result) // println("Got image:") // println(renderAsText(resultImage, 50, 50)) assertEquals(expectedImage.getWidth, resultImage.getWidth, "Width was different") assertEquals(expectedImage.getHeight, resultImage.getHeight, "Height was different") - for { + for x <- 0 until expectedImage.getWidth y <- 0 until expectedImage.getHeight - } { + do val equal = expectedImage.getRGB(x, y) == resultImage.getRGB(x, y) assert(equal, s"Pixel $x, $y was different. Output file: ${result.getAbsolutePath}") - } - } def renderAsText(bufferedImage: BufferedImage, w: Int, h: Int) = val downscaled = bufferedImage.getScaledInstance(w, h, Image.SCALE_SMOOTH) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenE2eTest.scala similarity index 100% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenE2eTest.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala index ec607279..7431960d 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala @@ -21,7 +21,7 @@ import scala.concurrent.ExecutionContext.Implicits class JuliaSet extends FunSuite: given ExecutionContext = Implicits.global - def runJuliaSet(referenceImgName: String)(using GContext): Unit = { + def runJuliaSet(referenceImgName: String)(using GContext): Unit = val dim = 4096 val max = 1 val RECURSION_LIMIT = 1000 @@ -70,7 +70,6 @@ class JuliaSet extends FunSuite: ImageUtility.renderToImage(r, dim, outputTemp.toPath) val referenceImage = getClass.getResource(referenceImgName) ImageTests.assertImagesEquals(outputTemp, new File(referenceImage.getPath)) - } test("Render julia set"): given GContext = new GContext diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala index 7cc5dc1b..200e16a6 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala @@ -3,15 +3,12 @@ package io.computenode.samples.cyfra.foton import io.computenode.cyfra import io.computenode.cyfra.* import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters -import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.library.Color.{InterpolationThemes, interpolate} import io.computenode.cyfra.dsl.library.Math3D.* -import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters import io.computenode.cyfra.foton.animation.AnimationFunctions.* +import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} import java.nio.file.Paths import scala.concurrent.duration.DurationInt diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala index fa6393a1..bd2c65de 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala @@ -1,16 +1,13 @@ package io.computenode.samples.cyfra.foton +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.library.Color.hex import io.computenode.cyfra.foton.* import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere} import io.computenode.cyfra.foton.rt.{Camera, Material} import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.library.Color.hex -import io.computenode.cyfra.dsl.given import java.nio.file.Paths import scala.concurrent.duration.DurationInt diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala index e498ce3e..b9c2279d 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala @@ -1,23 +1,17 @@ package io.computenode.samples.cyfra.oldsamples -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.runtime.* import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +import java.nio.file.Paths +import scala.annotation.tailrec +import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.concurrent.ExecutionContext.Implicits given GContext = new GContext() given ExecutionContext = Implicits.global @@ -44,15 +38,13 @@ def main = def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = { + def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - } - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = { + def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - } def ACESFilm(x: Vec3[Float32]): Vec3[Float32] = val a = 2.51f @@ -132,7 +124,8 @@ def main = dist < radiusA + radiusB val existingSpheres = mutable.Set.empty[((Float, Float, Float), Float)] - def randomSphere(iter: Int = 0): Sphere = { + @tailrec + def randomSphere(iter: Int = 0): Sphere = if iter > 1000 then throw new Exception("Could not find a non-intersecting sphere") def nextFloatAny = rd.nextFloat() * 2f - 1f @@ -141,7 +134,7 @@ def main = val center = (nextFloatAny * 10, nextFloatAny * 10, nextFloatPos * 10 + 8f) val radius = nextFloatPos + 1.5f if existingSpheres.exists(s => scalaTwoSpheresIntersect(s._1, s._2, center, radius)) then randomSphere(iter + 1) - else { + else existingSpheres.add((center, radius)) def color = (nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f) val emissive = (0f, 0f, 0f) @@ -158,19 +151,16 @@ def main = 0.1f, (nextFloatPos, nextFloatPos, nextFloatPos), ) - } - } def randomSpheres(n: Int) = List.fill(n)(randomSphere()) - val flash = { // flash + val flash = // flash val x = -10f val mX = -5f val y = -10f val mY = 0f val z = -5f Sphere((-7.5f, -12f, -5f), 3f, (1f, 1f, 1f), (20f, 20f, 20f)) - } val spheres = (flash :: randomSpheres(20)).map(sp => sp.copy(center = sp.center + sceneTranslation.xyz)) val walls = List( @@ -243,21 +233,19 @@ def main = def function(): GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): case (RaytracingIteration(frame), (xi: Int32, yi: Int32), lastFrame) => - def wangHash(seed: UInt32): UInt32 = { + def wangHash(seed: UInt32): UInt32 = val s1 = (seed ^ 61) ^ (seed >> 16) val s2 = s1 * 9 val s3 = s2 ^ (s2 >> 4) val s4 = s3 * 0x27d4eb2d s4 ^ (s4 >> 15) - } - def randomFloat(seed: UInt32): Random[Float32] = { + def randomFloat(seed: UInt32): Random[Float32] = val nextSeed = wangHash(seed) val f = nextSeed.asFloat / 4294967296.0f Random(f, nextSeed) - } - def randomVector(seed: UInt32): Random[Vec3[Float32]] = { + def randomVector(seed: UInt32): Random[Vec3[Float32]] = val Random(z, seed1) = randomFloat(seed) val z2 = z * 2.0f - 1.0f val Random(a, seed2) = randomFloat(seed1) @@ -266,15 +254,17 @@ def main = val x = r * cos(a2) val y = r * sin(a2) Random((x, y, z2), seed2) - } def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { - Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + val fixedQuad = + when((normal dot rayDir) > 0f): + Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) + .otherwise: + quad + val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -286,14 +276,15 @@ def main = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { - (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { - (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { - (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + val dist = + when(abs(rayDir.x) > 0.1f): + (intersectPoint.x - rayPos.x) / rayDir.x + .elseWhen(abs(rayDir.y) > 0.1f): + (intersectPoint.y - rayPos.y) / rayDir.y + .otherwise: + (intersectPoint.z - rayPos.z) / rayDir.z + + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo( dist, fixedNormal, @@ -307,24 +298,26 @@ def main = quad.refractionRoughness, quad.refractionColor, ) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -332,24 +325,24 @@ def main = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = val toRay = rayPos - sphere.center val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { - val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f))) + when(dist > minRayHitTime && dist < currentHit.dist): + val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f)) RayHitInfo( dist, normal, @@ -364,9 +357,10 @@ def main = sphere.refractionColor, fromInside, ) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testScene(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = @@ -384,22 +378,20 @@ def main = def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) val cosX = -(normal dot incident) - when(n1 > n2) { + when(n1 > n2): val n = n1 / n2 val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f) { + when(sinT2 > 1f): f90 - } otherwise { + .otherwise: val cosX2 = sqrt(1.0f - sinT2) val x = 1.0f - cosX2 val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } - } otherwise { + .otherwise: val x = 1.0f - cosX val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } val MaxBounces = 8 def getColorForRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], initRngState: UInt32): RayTraceState = @@ -407,96 +399,94 @@ def main = GSeq .gen[RayTraceState]( first = initState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngState, _) => - - val noHit = RayHitInfo(superFar, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) - val testResult = testScene(rayPos, rayDir, noHit) - when(testResult.dist < superFar) { - - val throughput2 = when(testResult.fromInside) { - throughput mulV exp[Vec3[Float32]](-testResult.refractionColor * testResult.dist) - }.otherwise { - throughput - } - - val specularChance = when(testResult.percentSpecular > 0.0f) { - fresnelReflectAmount( - when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), - when(!testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), - rayDir, - testResult.normal, - testResult.percentSpecular, - 1.0f, - ) - }.otherwise { - 0f - } - - val refractionChance = when(specularChance > 0.0f) { - testResult.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.percentSpecular)) - } otherwise testResult.refractionChance - - val Random(rayRoll, nextRngState1) = randomFloat(rngState) - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) { - 1.0f - }.otherwise(0.0f) - - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) { - 1.0f - }.otherwise(0.0f) - - val rayProbability = when(doSpecular === 1.0f) { - specularChance - }.elseWhen(doRefraction === 1.0f) { - refractionChance - }.otherwise { - 1.0f - (specularChance + refractionChance) - } - - val rayProbabilityCorrected = max(rayProbability, 0.01f) - - val nextRayPos = when(doRefraction === 1.0f) { - (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - }.otherwise { - (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - } - - val Random(randomVec1, nextRngState2) = randomVector(nextRngState1) - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.roughness * testResult.roughness)) - - val Random(randomVec2, nextRngState3) = randomVector(nextRngState2) - val refractionRayDirPerfect = - refract( - rayDir, - testResult.normal, - when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f / testResult.indexOfRefraction), - ) - val refractionRayDir = - normalize( - mix( - refractionRayDirPerfect, - normalize(-testResult.normal + randomVec2), - testResult.refractionRoughness * testResult.refractionRoughness, - ), - ) - - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - - val nextColor = (throughput2 mulV testResult.emissive) addV color - - val nextThroughput = when(doRefraction === 0.0f) { - throughput2 mulV mix[Vec3[Float32]](testResult.albedo, testResult.specularColor, doSpecular); - }.otherwise(throughput2) - - val throughputRayProb = nextThroughput * (1.0f / rayProbabilityCorrected) - - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, nextRngState3) - } otherwise RayTraceState(rayPos, rayDir, color, throughput, rngState, true) - - }, + next = + case state @ RayTraceState(rayPos, rayDir, color, throughput, rngState, _) => + val noHit = RayHitInfo(superFar, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) + val testResult = testScene(rayPos, rayDir, noHit) + when(testResult.dist < superFar): + val throughput2 = when(testResult.fromInside): + throughput mulV exp[Vec3[Float32]](-testResult.refractionColor * testResult.dist) + .otherwise: + throughput + + val specularChance = when(testResult.percentSpecular > 0.0f): + fresnelReflectAmount( + when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), + when(!testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), + rayDir, + testResult.normal, + testResult.percentSpecular, + 1.0f, + ) + .otherwise: + 0f + + val refractionChance = when(specularChance > 0.0f): + testResult.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.percentSpecular)) + .otherwise: + testResult.refractionChance + + val Random(rayRoll, nextRngState1) = randomFloat(rngState) + val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance): + 1.0f + .otherwise: + 0.0f + + val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance): + 1.0f + .otherwise: + 0.0f + + val rayProbability = when(doSpecular === 1.0f): + specularChance + .elseWhen(doRefraction === 1.0f): + refractionChance + .otherwise: + 1.0f - (specularChance + refractionChance) + + val rayProbabilityCorrected = max(rayProbability, 0.01f) + + val nextRayPos = when(doRefraction === 1.0f): + (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) + .otherwise: + (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) + + val Random(randomVec1, nextRngState2) = randomVector(nextRngState1) + val diffuseRayDir = normalize(testResult.normal + randomVec1) + val specularRayDirPerfect = reflect(rayDir, testResult.normal) + val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.roughness * testResult.roughness)) + + val Random(randomVec2, nextRngState3) = randomVector(nextRngState2) + val refractionRayDirPerfect = + refract( + rayDir, + testResult.normal, + when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f / testResult.indexOfRefraction), + ) + val refractionRayDir = + normalize( + mix( + refractionRayDirPerfect, + normalize(-testResult.normal + randomVec2), + testResult.refractionRoughness * testResult.refractionRoughness, + ), + ) + + val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) + val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) + + val nextColor = (throughput2 mulV testResult.emissive) addV color + + val nextThroughput = when(doRefraction === 0.0f): + throughput2 mulV mix[Vec3[Float32]](testResult.albedo, testResult.specularColor, doSpecular) + .otherwise: + throughput2 + + val throughputRayProb = nextThroughput * (1.0f / rayProbabilityCorrected) + + RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, nextRngState3) + .otherwise: + RayTraceState(rayPos, rayDir, color, throughput, rngState, true), ) .limit(MaxBounces) .takeWhile(!_.finished) @@ -528,9 +518,10 @@ def main = .limit(pixelIterationsPerFrame) .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) - when(frame === 0) { + when(frame === 0): (color, 1.0f) - } otherwise mix(lastFrame.at(xi, yi), (color, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) + .otherwise: + mix(lastFrame.at(xi, yi), (color, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) val initialMem = Array.fill(dim * dim)((0.5f, 0.5f, 0.5f, 0.5f)) val renders = 100 diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala index 56d9dd11..518687d3 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala @@ -1,19 +1,13 @@ package io.computenode.samples.cyfra.slides -import io.computenode.cyfra.given - -import scala.concurrent.Await -import scala.concurrent.duration.given -import io.computenode.cyfra.given +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.runtime.mem.FloatMem given GContext = new GContext() @main -def sample = +def sample() = val gpuFunction = GFunction: (value: Float32) => value * 2f diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala index c08855e9..6575bbaa 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala @@ -1,26 +1,16 @@ package io.computenode.samples.cyfra.slides -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.runtime.* import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +import java.nio.file.Paths @main -def simpleray = +def simpleRay() = val dim = 1024 val fovDeg = 60 @@ -32,11 +22,10 @@ def simpleray = val toRay = rayPos - sphereCenter val b = toRay dot rayDirection val c = (toRay dot toRay) - (sphereRadius * sphereRadius) - when((c < 0f || b < 0f) && b * b - c > 0f) { + when((c < 0f || b < 0f) && b * b - c > 0f): (1f, 1f, 1f, 1f) - } otherwise { + .otherwise: (0f, 0f, 0f, 1f) - } val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): case (_, (xi: Int32, yi: Int32), _) => diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala index 70fdd546..7784be27 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala @@ -1,29 +1,18 @@ package io.computenode.samples.cyfra.slides import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.* import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.struct.GStruct.Empty - -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given import io.computenode.cyfra.runtime.* import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +import java.nio.file.Paths @main -def rays = +def rays() = val raysPerPixel = 10 val dim = 1024 val fovDeg = 60 @@ -49,26 +38,28 @@ def rays = val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): val normal = normalize(rayPos + rayDir * dist - sphere.center) RayHitInfo(dist, normal, sphere.color, sphere.emissive) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -80,33 +71,35 @@ def rays = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -114,8 +107,8 @@ def rays = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit val sphere = Sphere(center = (1.5f, 1.5f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (3f, 3f, 3f)) diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala index 226ffab1..4ecd8e8b 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala @@ -1,32 +1,30 @@ package io.computenode.samples.cyfra.slides -import java.nio.file.Paths -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.* import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.runtime.* import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -def wangHash(seed: UInt32): UInt32 = { +import java.nio.file.Paths + +def wangHash(seed: UInt32): UInt32 = val s1 = (seed ^ 61) ^ (seed >> 16) val s2 = s1 * 9 val s3 = s2 ^ (s2 >> 4) val s4 = s3 * 0x27d4eb2d s4 ^ (s4 >> 15) -} case class Random[T <: Value](value: T, nextSeed: UInt32) -def randomFloat(seed: UInt32): Random[Float32] = { +def randomFloat(seed: UInt32): Random[Float32] = val nextSeed = wangHash(seed) val f = nextSeed.asFloat / 4294967296.0f Random(f, nextSeed) -} -def randomVector(seed: UInt32): Random[Vec3[Float32]] = { +def randomVector(seed: UInt32): Random[Vec3[Float32]] = val Random(z, seed1) = randomFloat(seed) val z2 = z * 2.0f - 1.0f val Random(a, seed2) = randomFloat(seed1) @@ -35,10 +33,9 @@ def randomVector(seed: UInt32): Random[Vec3[Float32]] = { val x = r * cos(a2) val y = r * sin(a2) Random((x, y, z2), seed2) -} @main -def randomRays = +def randomRays() = val raysPerPixel = 10 val dim = 1024 val fovDeg = 80 @@ -71,26 +68,28 @@ def randomRays = val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): val normal = normalize(rayPos + rayDir * dist - sphere.center) RayHitInfo(dist, normal, sphere.color, sphere.emissive) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -102,33 +101,34 @@ def randomRays = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -136,8 +136,8 @@ def randomRays = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f)) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala index 2212e11b..e6772e07 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala @@ -1,23 +1,11 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.collections.GArray2D import io.computenode.cyfra.foton.animation.AnimatedFunction.FunctionArguments import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.{*, given} - -import java.nio.file.{Path, Paths} -import scala.annotation.targetName -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt case class AnimatedFunction(fn: FunctionArguments => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds) extends AnimationRenderer.Scene diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala index 61b05aee..d8d6dff5 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala @@ -1,26 +1,17 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra +import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.{AnimationIteration, RenderFn} import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.runtime.{GContext, GFunction, UniformContext} -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed import io.computenode.cyfra.runtime.mem.GMem.fRGBA import io.computenode.cyfra.runtime.mem.Vec4FloatMem +import io.computenode.cyfra.runtime.{GContext, GFunction, UniformContext} -import java.nio.file.{Path, Paths} +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.{Await, ExecutionContext} -import scala.concurrent.duration.DurationInt class AnimatedFunctionRenderer(params: AnimatedFunctionRenderer.Parameters) extends AnimationRenderer[AnimatedFunction, AnimatedFunctionRenderer.RenderFn](params): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala index 28c92809..e1aa34e4 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala @@ -1,15 +1,10 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.given import io.computenode.cyfra -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.* import io.computenode.cyfra.dsl.Value.Float32 -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.RtRenderer object AnimationFunctions: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala index 72d0f032..df4ea7ca 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala @@ -2,21 +2,14 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.foton.rt.animation.AnimatedScene +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.runtime.GFunction +import io.computenode.cyfra.runtime.mem.GMem.fRGBA +import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.{*, given} -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import java.nio.file.{Path, Paths} -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt +import java.nio.file.Path trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters): @@ -25,7 +18,7 @@ trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Flo def renderFramesToDir(scene: S, destinationPath: Path): Unit = destinationPath.toFile.mkdirs() val images = renderFrames(scene) - val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt + val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt val requiredDigits = Math.ceil(Math.log10(totalFrames)).toInt images.zipWithIndex.foreach: case (image, i) => @@ -35,7 +28,7 @@ trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Flo def renderFrames(scene: S): LazyList[Array[fRGBA]] = val function = renderFunction(scene) - val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt + val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt val timestamps = LazyList.range(0, totalFrames).map(_ * msPerFrame) timestamps.zipWithIndex.map { case (time, frame) => timed(s"Animated frame $frame/$totalFrames"): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala index e65f94f1..6a314eb5 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala @@ -1,25 +1,18 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra -import ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.ImageRtRenderer import io.computenode.cyfra.* import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere} -import io.computenode.cyfra.runtime.{GFunction, UniformContext} -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration +import io.computenode.cyfra.runtime.mem.GMem.fRGBA +import io.computenode.cyfra.runtime.mem.Vec4FloatMem +import io.computenode.cyfra.runtime.{GFunction, UniformContext} +import io.computenode.cyfra.utility.ImageUtility +import io.computenode.cyfra.utility.Utility.timed -import java.nio.file.{Path, Paths} -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} +import java.nio.file.Path class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(params): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala index b35b37b2..1c591a71 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala @@ -1,25 +1,19 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere} +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.{GArray2D, GSeq} +import io.computenode.cyfra.dsl.control.Pure.pure import io.computenode.cyfra.dsl.library.Color.* import io.computenode.cyfra.dsl.library.Math3D.* +import io.computenode.cyfra.dsl.library.Random +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.runtime.GContext -import io.computenode.cyfra.dsl.control.Pure.pure -import java.nio.file.{Path, Paths} -import scala.collection.mutable +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.collections.{GArray2D, GSeq} -import io.computenode.cyfra.dsl.library.Random -import io.computenode.cyfra.dsl.struct.GStruct class RtRenderer(params: RtRenderer.Parameters): @@ -37,14 +31,13 @@ class RtRenderer(params: RtRenderer.Parameters): ) extends GStruct[RayTraceState] private def applyRefractionThroughput(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.fromInside) { + when(testResult.fromInside): state.throughput mulV exp[Vec3[Float32]](-testResult.material.refractionColor * testResult.dist) - }.otherwise { + .otherwise: state.throughput - } private def calculateSpecularChance(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.material.percentSpecular > 0.0f) { + when(testResult.material.percentSpecular > 0.0f): val material = testResult.material fresnelReflectAmount( when(testResult.fromInside)(material.indexOfRefraction).otherwise(1.0f), @@ -54,44 +47,42 @@ class RtRenderer(params: RtRenderer.Parameters): material.percentSpecular, 1.0f, ) - }.otherwise { + .otherwise: 0f - } private def getRefractionChance(state: RayTraceState, testResult: RayHitInfo, specularChance: Float32) = pure: - when(specularChance > 0.0f) { + when(specularChance > 0.0f): testResult.material.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.material.percentSpecular)) - } otherwise testResult.material.refractionChance + .otherwise: + testResult.material.refractionChance private case class RayAction(doSpecular: Float32, doRefraction: Float32, rayProbability: Float32) private def getRayAction(state: RayTraceState, testResult: RayHitInfo, random: Random): (RayAction, Random) = val specularChance = calculateSpecularChance(state, testResult) val refractionChance = getRefractionChance(state, testResult, specularChance) val (nextRandom, rayRoll) = random.next[Float32] - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) { + val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance): 1.0f - }.otherwise(0.0f) - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) { + .otherwise(0.0f) + val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance): 1.0f - }.otherwise(0.0f) + .otherwise(0.0f) - val rayProbability = when(doSpecular === 1.0f) { + val rayProbability = when(doSpecular === 1.0f): specularChance - }.elseWhen(doRefraction === 1.0f) { + .elseWhen(doRefraction === 1.0f): refractionChance - }.otherwise { + .otherwise: 1.0f - (specularChance + refractionChance) - } (RayAction(doSpecular, doRefraction, max(rayProbability, 0.01f)), nextRandom) private val rayPosNormalNudge = 0.01f private def getNextRayPos(rayPos: Vec3[Float32], rayDir: Vec3[Float32], testResult: RayHitInfo, doRefraction: Float32) = pure: - when(doRefraction =~= 1.0f) { + when(doRefraction =~= 1.0f): (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - }.otherwise { + .otherwise: (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - } private def getRefractionRayDir(rayDir: Vec3[Float32], testResult: RayHitInfo, random: Random) = val (random2, randomVec) = random.next[Vec3[Float32]] @@ -117,9 +108,10 @@ class RtRenderer(params: RtRenderer.Parameters): rayProbability: Float32, refractedThroughput: Vec3[Float32], ) = pure: - val nextThroughput = when(doRefraction === 0.0f) { - refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular); - }.otherwise(refractedThroughput) + val nextThroughput = when(doRefraction === 0.0f): + refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular) + .otherwise: + refractedThroughput nextThroughput * (1.0f / rayProbability) private def bounceRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], random: Random, scene: Scene): RayTraceState = @@ -127,35 +119,35 @@ class RtRenderer(params: RtRenderer.Parameters): GSeq .gen[RayTraceState]( first = initState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => - - val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) - val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) + next = + case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => + val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) + val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) - when(testResult.dist < params.superFar) { - val refractedThroughput = applyRefractionThroughput(state, testResult) + when(testResult.dist < params.superFar): + val refractedThroughput = applyRefractionThroughput(state, testResult) - val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) + val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) - val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) + val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) - val (random3, randomVec1) = random2.next[Vec3[Float32]] - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) + val (random3, randomVec1) = random2.next[Vec3[Float32]] + val diffuseRayDir = normalize(testResult.normal + randomVec1) + val specularRayDirPerfect = reflect(rayDir, testResult.normal) + val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) - val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) + val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) + val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) + val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color + val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color - val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) + val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) - } otherwise RayTraceState(rayPos, rayDir, color, throughput, random, true) - }, + RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) + .otherwise: + RayTraceState(rayPos, rayDir, color, throughput, random, true), ) .limit(params.maxBounces) .takeWhile(!_.finished) @@ -190,9 +182,10 @@ class RtRenderer(params: RtRenderer.Parameters): val colorCorrected = linearToSRGB(color) - when(frame === 0) { + when(frame === 0): (colorCorrected, 1.0f) - } otherwise mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) + .otherwise: + mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) object RtRenderer: trait Parameters: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala index 46d47300..ef7a03b6 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala @@ -3,8 +3,7 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra.dsl.Value.{Float32, Vec3} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.foton.rt.shapes.{Shape, ShapeCollection} -import io.computenode.cyfra.{*, given} -import izumi.reflect.Tag +import io.computenode.cyfra.given import scala.util.chaining.* diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala index 5367f800..2674c237 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala @@ -2,22 +2,14 @@ package io.computenode.cyfra.foton.rt.animation import io.computenode.cyfra import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.runtime.{GFunction, UniformContext} +import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.given - -import java.nio.file.{Path, Paths} -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt +import io.computenode.cyfra.runtime.{GFunction, UniformContext} class AnimationRtRenderer(params: AnimationRtRenderer.Parameters) extends RtRenderer(params) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala index 58ca8bf5..fe980b9e 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala @@ -30,16 +30,14 @@ object Box: val tEnter = max(tMinX, tMinY, tMinZ) val tExit = min(tMaxX, tMaxY, tMaxZ) - when(tEnter < tExit || tExit < 0.0f) { + when(tEnter < tExit || tExit < 0.0f): currentHit - } otherwise { + .otherwise: val hitDistance = when(tEnter > 0f)(tEnter).otherwise(tExit) - val hitNormal = when(tEnter =~= tMinX) { + val hitNormal = when(tEnter =~= tMinX): (when(rayDir.x > 0f)(-1f).otherwise(1f), 0f, 0f) - }.elseWhen(tEnter =~= tMinY) { + .elseWhen(tEnter =~= tMinY): (0f, when(rayDir.y > 0f)(-1f).otherwise(1f), 0f) - }.otherwise { + .otherwise: (0f, 0f, when(rayDir.z > 0f)(-1f).otherwise(1f)) - } RayHitInfo(hitDistance, hitNormal, box.material) - } diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala index b0188750..fd9e3eee 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala @@ -16,14 +16,13 @@ object Plane: def testRay(plane: Plane, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: val denom = plane.normal dot rayDir given epsilon: Float32 = 0.1f - when(denom =~= 0.0f) { + when(denom =~= 0.0f): currentHit - } otherwise { + .otherwise: val t = ((plane.point - rayPos) dot plane.normal) / denom - when(t < 0.0f || t >= currentHit.dist) { + when(t < 0.0f || t >= currentHit.dist): currentHit - } otherwise { - val hitNormal = when(denom < 0.0f)(plane.normal).otherwise(-plane.normal) + .otherwise: + val hitNormal = when(denom < 0.0f)(plane.normal) + .otherwise(-plane.normal) RayHitInfo(t, hitNormal, plane.material) - } - } diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala index fbdcc1b7..58b2d641 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala @@ -21,9 +21,10 @@ object Quad: given TestRay[Quad] with def testRay(quad: Quad, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.material) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -35,33 +36,34 @@ object Quad: val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > MinRayHitTime && dist < currentHit.dist) { + when(dist > MinRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.material) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -69,5 +71,5 @@ object Quad: val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala index 9fb3f891..24af9919 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.library.Functions.* import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo trait Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala index e80ca655..efe2c76a 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala @@ -1,15 +1,15 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.shapes.* -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import izumi.reflect.Tag -import io.computenode.cyfra.dsl.library.Functions.* -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.dsl.library.Functions.* import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.foton.rt.Material +import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo +import io.computenode.cyfra.foton.rt.shapes.* import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay +import izumi.reflect.Tag import scala.util.chaining.* diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala index 0328780b..0e0d556c 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala @@ -1,17 +1,11 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} - -import java.nio.file.Paths -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.foton.rt.Material +import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay case class Sphere(center: Vec3[Float32], radius: Float32, material: Material) extends GStruct[Sphere] with Shape @@ -23,17 +17,18 @@ object Sphere: val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > MinRayHitTime && dist < currentHit.dist) { - val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f))) + when(dist > MinRayHitTime && dist < currentHit.dist): + val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f)) RayHitInfo(dist, normal, sphere.material, fromInside) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala index 164458bf..b72be392 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala @@ -5,6 +5,5 @@ import io.computenode.cyfra.runtime.mem.{GMem, RamGMem} import scala.concurrent.Future -trait Executable[H <: Value, R <: Value] { +trait Executable[H <: Value, R <: Value]: def execute(input: GMem[H], output: RamGMem[R, ?]): Future[Unit] -} diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala index 90f0f91c..3ebd43d9 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala @@ -29,7 +29,7 @@ class GContext(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()): def compile[G <: GStruct[G]: Tag: GStructSchema, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr]( function: GFunction[G, H, R], - ): ComputePipeline = { + ): ComputePipeline = val uniformStructSchema = summon[GStructSchema[G]] val uniformStruct = uniformStructSchema.fromTree(UniformStructRef) val tree = function.fn @@ -44,7 +44,6 @@ class GContext(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()): val shader = Shader(optimizedShaderCode, org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device) ComputePipeline(shader, vkContext) - } def execute[G <: GStruct[G]: Tag: GStructSchema, H <: Value, R <: Value](mem: GMem[H], fn: GFunction[G, H, R])(using uniformContext: UniformContext[G], diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala index 02a9e094..1c85b3fd 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala @@ -9,11 +9,10 @@ import izumi.reflect.Tag case class GFunction[G <: GStruct[G]: GStructSchema: Tag, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: (G, Int32, GArray[H]) => R)( implicit context: GContext, -) { +): def arrayInputs: List[Tag[?]] = List(summon[Tag[H]]) def arrayOutputs: List[Tag[?]] = List(summon[Tag[R]]) val pipeline: ComputePipeline = context.compile(this) -} object GFunction: def apply[H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: H => R)(using context: GContext): GFunction[GStruct.Empty, H, R] = diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala index a0c6078f..4264233d 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala @@ -4,7 +4,6 @@ import io.computenode.cyfra.dsl.Value.Float32 import org.lwjgl.BufferUtils import java.nio.ByteBuffer -import org.lwjgl.system.MemoryUtil class FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Float32, Float]: def toArray: Array[Float] = @@ -13,7 +12,7 @@ class FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Fl res.get(result) result -object FloatMem { +object FloatMem: val FloatSize = 4 def apply(floats: Array[Float]): FloatMem = @@ -26,4 +25,3 @@ object FloatMem { def apply(size: Int): FloatMem = val data = BufferUtils.createByteBuffer(size * FloatSize) new FloatMem(size, data) -} diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala index 42246671..a6efd211 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala @@ -1,14 +1,12 @@ package io.computenode.cyfra.runtime.mem -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import io.computenode.cyfra.runtime.{GContext, GFunction} import io.computenode.cyfra.dsl.struct.* +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.runtime.{GContext, GFunction, UniformContext} +import io.computenode.cyfra.spirv.SpirvTypes.typeStride import izumi.reflect.Tag import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil -import io.computenode.cyfra.runtime.UniformContext import java.nio.ByteBuffer @@ -31,9 +29,9 @@ object GMem: typeStride(t) }.sum - def serializeUniform(g: GStruct[?]): ByteBuffer = { + def serializeUniform(g: GStruct[?]): ByteBuffer = val data = BufferUtils.createByteBuffer(totalStride(g.schema)) - g.productIterator.foreach { + g.productIterator.foreach: case Int32(ConstInt32(i)) => data.putInt(i) case Float32(ConstFloat32(f)) => data.putFloat(f) case Vec4(ComposeVec4(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)), Float32(ConstFloat32(z)), Float32(ConstFloat32(a)))) => @@ -50,7 +48,5 @@ object GMem: data.putFloat(y) case illegal => throw new IllegalArgumentException(s"Uniform must be constructed from constants (got field $illegal)") - } data.rewind() data - } diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala index 2c246aab..72d12a82 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala @@ -4,7 +4,6 @@ import io.computenode.cyfra.dsl.Value.Int32 import org.lwjgl.BufferUtils import java.nio.ByteBuffer -import org.lwjgl.system.MemoryUtil class IntMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Int32, Int]: def toArray: Array[Int] = diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala index 710781ea..ff48aa6b 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala @@ -3,22 +3,20 @@ package io.computenode.cyfra.runtime.mem import io.computenode.cyfra.dsl.Value.{Float32, Vec4} import io.computenode.cyfra.runtime.mem.GMem.fRGBA import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil import java.nio.ByteBuffer class Vec4FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Vec4[Float32], fRGBA]: - def toArray: Array[fRGBA] = { + def toArray: Array[fRGBA] = val res = data.asFloatBuffer() val result = new Array[fRGBA](size) for i <- 0 until size do result(i) = (res.get(), res.get(), res.get(), res.get()) result - } object Vec4FloatMem: val Vec4FloatSize = 16 - def apply(vecs: Array[fRGBA]): Vec4FloatMem = { + def apply(vecs: Array[fRGBA]): Vec4FloatMem = val size = vecs.length val data = BufferUtils.createByteBuffer(size * Vec4FloatSize) vecs.foreach { case (x, y, z, a) => @@ -29,7 +27,6 @@ object Vec4FloatMem: } data.rewind() new Vec4FloatMem(size, data) - } def apply(size: Int): Vec4FloatMem = val data = BufferUtils.createByteBuffer(size * Vec4FloatSize) diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala index e845adde..f0c7c38f 100644 --- a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala @@ -4,8 +4,6 @@ import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogge import io.computenode.cyfra.utility.Logger.logger import java.nio.ByteBuffer -import java.nio.charset.StandardCharsets -import java.nio.file.Files object SpirvDisassembler extends SpirvTool("spirv-dis"): diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala index 6a21f550..5ce14f10 100644 --- a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala @@ -3,17 +3,14 @@ package io.computenode.cyfra.spirvtools import io.computenode.cyfra.spirvtools.SpirvCross.Enable import munit.FunSuite -class SpirvCrossTest extends FunSuite { +class SpirvCrossTest extends FunSuite: - test("SPIR-V cross compilation succeeded") { + test("SPIR-V cross compilation succeeded"): val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") - val glslShader = SpirvCross.crossCompileSpirv(shaderCode, crossCompilation = Enable(throwOnFail = true)) match { + val glslShader = SpirvCross.crossCompileSpirv(shaderCode, crossCompilation = Enable(throwOnFail = true)) match case None => fail("Failed to disassemble shader.") case Some(assembly) => assembly - } val referenceGlsl = SpirvTestUtils.loadResourceAsString("optimized.glsl") assertEquals(glslShader, referenceGlsl) - } -} diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala index 0138b0cd..1918285b 100644 --- a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala @@ -3,17 +3,14 @@ package io.computenode.cyfra.spirvtools import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable import munit.FunSuite -class SpirvDisassemblerTest extends FunSuite { +class SpirvDisassemblerTest extends FunSuite: - test("SPIR-V disassembly succeeded") { + test("SPIR-V disassembly succeeded"): val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") - val assembly = SpirvDisassembler.disassembleSpirv(shaderCode, disassembly = Enable(throwOnFail = true)) match { + val assembly = SpirvDisassembler.disassembleSpirv(shaderCode, disassembly = Enable(throwOnFail = true)) match case None => fail("Failed to disassemble shader.") case Some(assembly) => assembly - } val referenceAssembly = SpirvTestUtils.loadResourceAsString("optimized.spvasm") assertEquals(assembly, referenceAssembly) - } -} diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala index 5557cfe0..a10925f8 100644 --- a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala @@ -6,19 +6,16 @@ import munit.FunSuite import java.nio.ByteBuffer -class SpirvOptimizerTest extends FunSuite { +class SpirvOptimizerTest extends FunSuite: - test("SPIR-V optimization succeeded") { + test("SPIR-V optimization succeeded"): val shaderCode = SpirvTestUtils.loadShaderFromResources("original.spv") - val optimizedShaderCode = SpirvOptimizer.optimizeSpirv(shaderCode, SpirvOptimizer.Enable(throwOnFail = true, settings = Seq(Param("-O")))) match { + val optimizedShaderCode = SpirvOptimizer.optimizeSpirv(shaderCode, SpirvOptimizer.Enable(throwOnFail = true, settings = Seq(Param("-O")))) match case None => fail("Failed to optimize shader code.") case Some(optimizedShaderCode) => optimizedShaderCode - } val optimizedAssembly = SpirvDisassembler.disassembleSpirv(optimizedShaderCode, disassembly = Enable(throwOnFail = true)) val referenceOptimizedShaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") val referenceAssembly = SpirvDisassembler.disassembleSpirv(referenceOptimizedShaderCode, disassembly = Enable(throwOnFail = true)) assertEquals(optimizedAssembly, referenceAssembly) - } -} diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala index ccb74760..201b6f73 100644 --- a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala @@ -4,21 +4,19 @@ import java.nio.ByteBuffer import java.nio.file.{Files, Paths} import scala.io.Source -object SpirvTestUtils { - def loadShaderFromResources(path: String): ByteBuffer = { +object SpirvTestUtils: + def loadShaderFromResources(path: String): ByteBuffer = val resourceUrl = getClass.getClassLoader.getResource(path) require(resourceUrl != null, s"Resource not found: $path") val bytes = Files.readAllBytes(Paths.get(resourceUrl.toURI)) ByteBuffer.wrap(bytes) - } - def loadResourceAsString(path: String): String = { + def loadResourceAsString(path: String): String = val source = Source.fromResource(path) try source.mkString finally source.close() - } - def corruptMagicNumber(original: ByteBuffer): ByteBuffer = { + def corruptMagicNumber(original: ByteBuffer): ByteBuffer = val corrupted = ByteBuffer.allocate(original.capacity()) original.rewind() corrupted.put(original) @@ -27,5 +25,3 @@ object SpirvTestUtils { corrupted.rewind() corrupted - } -} diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala index b819aec0..2fd2b8c5 100644 --- a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala @@ -6,17 +6,16 @@ import java.io.{ByteArrayOutputStream, File} import java.nio.ByteBuffer import java.nio.file.Files -class SpirvToolTest extends FunSuite { +class SpirvToolTest extends FunSuite: private def isWindows: Boolean = System.getProperty("os.name").toLowerCase.contains("win") - class TestSpirvTool(toolName: String) extends SpirvTool(toolName) { + class TestSpirvTool(toolName: String) extends SpirvTool(toolName): def runExecuteCmd(input: ByteBuffer, cmd: Seq[String]): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] = executeSpirvCmd(input, cmd) - } if !isWindows then - test("executeSpirvCmd returns exit code and output streams on valid command") { + test("executeSpirvCmd returns exit code and output streams on valid command"): val tool = new TestSpirvTool("cat") val inputBytes = "hello SPIR-V".getBytes("UTF-8") @@ -33,9 +32,8 @@ class SpirvToolTest extends FunSuite { assertEquals(exitCode, 0) assert(outputString == "hello SPIR-V") assertEquals(errStream.size(), 0) - } - test("executeSpirvCmd returns non-zero exit code on invalid command") { + test("executeSpirvCmd returns non-zero exit code on invalid command"): val tool = new TestSpirvTool("invalid-cmd") val byteBuffer = ByteBuffer.wrap("".getBytes("UTF-8")) @@ -45,9 +43,8 @@ class SpirvToolTest extends FunSuite { assert(result.isLeft) val error = result.left.getOrElse(fail("Should have error")) assert(error.getMessage.contains("Failed to execute SPIR-V command")) - } - test("dumpSpvToFile writes ByteBuffer content to file") { + test("dumpSpvToFile writes ByteBuffer content to file"): val tmpFile = Files.createTempFile("spirv-dump-test", ".spv") val data = "SPIRV binary data".getBytes("UTF-8") @@ -62,10 +59,7 @@ class SpirvToolTest extends FunSuite { assert(buffer.position() == 0) Files.deleteIfExists(tmpFile) - } - test("Param.asStringParam returns correct string") { + test("Param.asStringParam returns correct string"): val param = SpirvTool.Param("test-value") assertEquals(param.asStringParam, "test-value") - } -} diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala index df49b5c3..a857e5fb 100644 --- a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala @@ -3,31 +3,26 @@ package io.computenode.cyfra.spirvtools import io.computenode.cyfra.spirvtools.SpirvValidator.Enable import munit.FunSuite -class SpirvValidatorTest extends FunSuite { +class SpirvValidatorTest extends FunSuite: - test("SPIR-V validation succeeded") { + test("SPIR-V validation succeeded"): val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") - try { + try SpirvValidator.validateSpirv(shaderCode, validation = Enable(throwOnFail = true)) assert(true) - } catch { + catch case e: Throwable => fail(s"Validation unexpectedly failed: ${e.getMessage}") - } - } - test("SPIR-V validation fail") { + test("SPIR-V validation fail"): val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") val corruptedShaderCode = SpirvTestUtils.corruptMagicNumber(shaderCode) - try { + try SpirvValidator.validateSpirv(corruptedShaderCode, validation = Enable(throwOnFail = true)) fail(s"Validation was supposed to fail.") - } catch { + catch case e: Throwable => val result = e.getMessage assertEquals(result, "SPIR-V validation failed with exit code 1.\nValidation errors:\nerror: line 0: Invalid SPIR-V magic number.\n") - } - } -} diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala index 81478b63..77b8532a 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala @@ -5,20 +5,16 @@ import java.io.File import java.nio.file.Path import javax.imageio.ImageIO -object ImageUtility { +object ImageUtility: def renderToImage(arr: Array[(Float, Float, Float, Float)], n: Int, location: Path): Unit = renderToImage(arr, n, n, location) - def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = { + def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = val image = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB) for y <- 0 until h do - for x <- 0 until w do { + for x <- 0 until w do val (r, g, b, _) = arr(y * w + x) def clip(f: Float) = Math.min(1.0f, Math.max(0.0f, f)) val (iR, iG, iB) = ((clip(r) * 255).toInt, (clip(g) * 255).toInt, (clip(b) * 255).toInt) image.setRGB(x, y, (iR << 16) | (iG << 8) | iB) - } val outputFile = location.toFile ImageIO.write(image, "png", outputFile) - } - -} diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index 31a98b8b..a081d60a 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -1,7 +1,6 @@ package io.computenode.cyfra.utility import io.computenode.cyfra.utility.Logger.logger -import org.slf4j.LoggerFactory object Utility: diff --git a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala similarity index 95% rename from cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala rename to cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala index cf994b47..f85e2fd6 100644 --- a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala +++ b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.vscode.VscodeConnection.Message import java.net.http.{HttpClient, WebSocket} -class VscodeConnection(host: String, port: Int) { +class VscodeConnection(host: String, port: Int): val ws = HttpClient .newHttpClient() .newWebSocketBuilder() @@ -13,7 +13,6 @@ class VscodeConnection(host: String, port: Int) { def send(message: Message): Unit = ws.sendText(message.toJson, true) -} object VscodeConnection: trait Message: diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala index a12e1c7e..7b3231cd 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala @@ -9,13 +9,12 @@ import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool} /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] object VulkanContext { +private[cyfra] object VulkanContext: val ValidationLayer: String = "VK_LAYER_KHRONOS_validation" val SyncLayer: String = "VK_LAYER_KHRONOS_synchronization2" private val ValidationLayers: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean -} -private[cyfra] class VulkanContext { +private[cyfra] class VulkanContext: val instance: Instance = new Instance(ValidationLayers) val debugCallback: Option[DebugCallback] = if ValidationLayers then Some(new DebugCallback(instance)) else None val device: Device = new Device(instance) @@ -27,7 +26,7 @@ private[cyfra] class VulkanContext { logger.debug("Vulkan context created") logger.debug("Running on device: " + device.physicalDeviceName) - def destroy(): Unit = { + def destroy(): Unit = commandPool.destroy() descriptorPool.destroy() allocator.destroy() @@ -35,5 +34,3 @@ private[cyfra] class VulkanContext { device.destroy() debugCallback.foreach(_.destroy()) instance.destroy() - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala index fd43daa2..e595b0ec 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala @@ -1,20 +1,16 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val createInfo = VkCommandPoolCreateInfo .calloc(stack) .sType$Default() @@ -25,12 +21,11 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends val pCommandPoll = stack.callocLong(1) check(vkCreateCommandPool(device.get, createInfo, null, pCommandPoll), "Failed to create command pool") pCommandPoll.get() - } private val commandPool = handle def beginSingleTimeCommands(): VkCommandBuffer = - pushStack { stack => + pushStack: stack => val commandBuffer = this.createCommandBuffer() val beginInfo = VkCommandBufferBeginInfo @@ -40,12 +35,11 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends check(vkBeginCommandBuffer(commandBuffer, beginInfo), "Failed to begin single time command buffer") commandBuffer - } def createCommandBuffer(): VkCommandBuffer = createCommandBuffers(1).head - def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack { stack => + def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack: stack => val allocateInfo = VkCommandBufferAllocateInfo .calloc(stack) .sType$Default() @@ -56,33 +50,27 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends val pointerBuffer = stack.callocPointer(n) check(vkAllocateCommandBuffers(device.get, allocateInfo, pointerBuffer), "Failed to allocate command buffers") 0 until n map (i => pointerBuffer.get(i)) map (new VkCommandBuffer(_, device.get)) - } def endSingleTimeCommands(commandBuffer: VkCommandBuffer): Fence = - pushStack { stack => + pushStack: stack => vkEndCommandBuffer(commandBuffer) - val pointerBuffer = stack.callocPointer(1).put(0, commandBuffer) val submitInfo = VkSubmitInfo .calloc(stack) .sType$Default() .pCommandBuffers(pointerBuffer) - val fence = new Fence(device, 0, () => freeCommandBuffer(commandBuffer)) queue.submit(submitInfo, fence) fence - } def freeCommandBuffer(commandBuffer: VkCommandBuffer*): Unit = - pushStack { stack => + pushStack: stack => val pointerBuffer = stack.callocPointer(commandBuffer.length) commandBuffer.foreach(pointerBuffer.put) pointerBuffer.flip() vkFreeCommandBuffers(device.get, commandPool, pointerBuffer) - } protected def close(): Unit = vkDestroyCommandPool(device.get, commandPool, null) protected def getFlags: Int -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala index 85c4ac96..31f16d8c 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala @@ -1,21 +1,16 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkFenceCreateInfo -import java.nio.LongBuffer -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit = () => ()) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit = () => ()) extends VulkanObjectHandle: + protected val handle: Long = pushStack(stack => val fenceInfo = VkFenceCreateInfo .calloc(stack) .sType$Default() @@ -24,33 +19,27 @@ private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit val pFence = stack.callocLong(1) check(vkCreateFence(device.get, fenceInfo, null, pFence), "Failed to create fence") - pFence.get() - } + pFence.get(), + ) - override def close(): Unit = { + override def close(): Unit = onDestroy.apply() vkDestroyFence(device.get, handle, null) - } - def isSignaled: Boolean = { + def isSignaled: Boolean = val result = vkGetFenceStatus(device.get, handle) if !(result == VK_SUCCESS || result == VK_NOT_READY) then throw new VulkanAssertionError("Failed to get fence status", result) result == VK_SUCCESS - } - def reset(): Fence = { + def reset(): Fence = vkResetFences(device.get, handle) this - } - def block(): Fence = { + def block(): Fence = block(Long.MaxValue) this - } - def block(timeout: Long): Boolean = { - val err = vkWaitForFences(device.get, handle, true, timeout); - if err != VK_SUCCESS && err != VK_TIMEOUT then throw new VulkanAssertionError("Failed to wait for fences", err); + def block(timeout: Long): Boolean = + val err = vkWaitForFences(device.get, handle, true, timeout) + if err != VK_SUCCESS && err != VK_TIMEOUT then throw new VulkanAssertionError("Failed to wait for fences", err) err == VK_SUCCESS; - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala index a6db2fe2..67621756 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala @@ -6,7 +6,5 @@ import org.lwjgl.vulkan.VK10.VK_COMMAND_POOL_CREATE_TRANSIENT_BIT /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] class OneTimeCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) { +private[cyfra] class OneTimeCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue): protected def getFlags: Int = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala index bbc5ce70..506ee37c 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala @@ -1,31 +1,23 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.pushStack import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.pushStack import io.computenode.cyfra.vulkan.util.VulkanObject -import org.lwjgl.PointerBuffer -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.{vkGetDeviceQueue, vkQueueSubmit} import org.lwjgl.vulkan.{VkQueue, VkSubmitInfo} -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject { - private val queue: VkQueue = pushStack { stack => +private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject: + private val queue: VkQueue = pushStack: stack => val pQueue = stack.callocPointer(1) vkGetDeviceQueue(device.get, familyIndex, queueIndex, pQueue) new VkQueue(pQueue.get(0), device.get) - } - def submit(submitInfo: VkSubmitInfo, fence: Fence): Int = this.synchronized { + def submit(submitInfo: VkSubmitInfo, fence: Fence): Int = this.synchronized: vkQueueSubmit(queue, submitInfo, fence.get) - } def get: VkQueue = queue protected def close(): Unit = () -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala index a1777d2a..e65b145a 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala @@ -1,29 +1,22 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkSemaphoreCreateInfo -import scala.util.Using - /** @author * MarconZet Created 30.10.2019 */ -private[cyfra] class Semaphore(device: Device) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class Semaphore(device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val semaphoreCreateInfo = VkSemaphoreCreateInfo .calloc(stack) .sType$Default() val pointer = stack.callocLong(1) check(vkCreateSemaphore(device.get, semaphoreCreateInfo, null, pointer), "Failed to create semaphore") pointer.get() - } def close(): Unit = vkDestroySemaphore(device.get, handle, null) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala index e2eb7bad..a7127f4a 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala @@ -5,6 +5,5 @@ import io.computenode.cyfra.vulkan.core.Device /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] class StandardCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) { +private[cyfra] class StandardCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue): protected def getFlags: Int = 0 -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala index ca2456de..d452b13f 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala @@ -1,24 +1,20 @@ package io.computenode.cyfra.vulkan.compute -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* -import scala.util.Using - /** @author * MarconZet Created 14.04.2020 */ -private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanContext) extends VulkanObjectHandle { +private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanContext) extends VulkanObjectHandle: private val device: Device = context.device val descriptorSetLayouts: Seq[(Long, LayoutSet)] = computeShader.layoutInfo.sets.map(x => (createDescriptorSetLayout(x), x)) - val pipelineLayout: Long = pushStack { stack => + val pipelineLayout: Long = pushStack: stack => val pipelineLayoutCreateInfo = VkPipelineLayoutCreateInfo .calloc(stack) .sType$Default() @@ -30,8 +26,8 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC val pPipelineLayout = stack.callocLong(1) check(vkCreatePipelineLayout(device.get, pipelineLayoutCreateInfo, null, pPipelineLayout), "Failed to create pipeline layout") pPipelineLayout.get(0) - } - protected val handle: Long = pushStack { stack => + + protected val handle: Long = pushStack: stack => val pipelineShaderStageCreateInfo = VkPipelineShaderStageCreateInfo .calloc(stack) .sType$Default() @@ -55,17 +51,15 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC val pPipeline = stack.callocLong(1) check(vkCreateComputePipelines(device.get, 0, computePipelineCreateInfo, null, pPipeline), "Failed to create compute pipeline") pPipeline.get(0) - } - protected def close(): Unit = { + protected def close(): Unit = vkDestroyPipeline(device.get, handle, null) vkDestroyPipelineLayout(device.get, pipelineLayout, null) descriptorSetLayouts.map(_._1).foreach(vkDestroyDescriptorSetLayout(device.get, _, null)) - } - private def createDescriptorSetLayout(set: LayoutSet): Long = pushStack { stack => + private def createDescriptorSetLayout(set: LayoutSet): Long = pushStack: stack => val descriptorSetLayoutBindings = VkDescriptorSetLayoutBinding.calloc(set.bindings.length, stack) - set.bindings.foreach { binding => + set.bindings.foreach: binding => descriptorSetLayoutBindings .get() .binding(binding.id) @@ -75,7 +69,7 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC .descriptorCount(1) .stageFlags(VK_SHADER_STAGE_COMPUTE_BIT) .pImmutableSamplers(null) - } + descriptorSetLayoutBindings.flip() val descriptorSetLayoutCreateInfo = VkDescriptorSetLayoutCreateInfo @@ -88,5 +82,3 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC val pDescriptorSetLayout = stack.callocLong(1) check(vkCreateDescriptorSetLayout(device.get, descriptorSetLayoutCreateInfo, null, pDescriptorSetLayout), "Failed to create descriptor set layout") pDescriptorSetLayout.get(0) - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala index ac3924b7..6032e37f 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala @@ -1,20 +1,16 @@ package io.computenode.cyfra.vulkan.compute -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.joml.Vector3ic -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkShaderModuleCreateInfo import java.io.{File, FileInputStream, IOException} +import java.nio.ByteBuffer import java.nio.channels.FileChannel -import java.nio.{ByteBuffer, LongBuffer} -import java.util.stream.Collectors -import java.util.{List, Objects} -import scala.util.Using +import java.util.Objects /** @author * MarconZet Created 25.04.2020 @@ -25,9 +21,9 @@ private[cyfra] class Shader( val layoutInfo: LayoutInfo, val functionName: String, device: Device, -) extends VulkanObjectHandle { +) extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val shaderModuleCreateInfo = VkShaderModuleCreateInfo .calloc(stack) .sType$Default() @@ -38,25 +34,21 @@ private[cyfra] class Shader( val pShaderModule = stack.callocLong(1) check(vkCreateShaderModule(device.get, shaderModuleCreateInfo, null, pShaderModule), "Failed to create shader module") pShaderModule.get() - } protected def close(): Unit = vkDestroyShaderModule(device.get, handle, null) -} -object Shader { +object Shader: def loadShader(path: String): ByteBuffer = loadShader(path, getClass.getClassLoader) private def loadShader(path: String, classLoader: ClassLoader): ByteBuffer = - try { + try val file = new File(Objects.requireNonNull(classLoader.getResource(path)).getFile) val fis = new FileInputStream(file) val fc = fis.getChannel fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()) - } catch + catch case e: IOException => throw new RuntimeException(e) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala index c09a1e7a..4c2c37ca 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala @@ -1,29 +1,25 @@ package io.computenode.cyfra.vulkan.core -import DebugCallback.DEBUG_REPORT import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.util.Util.check +import io.computenode.cyfra.vulkan.core.DebugCallback.DEBUG_REPORT import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} import org.lwjgl.BufferUtils import org.lwjgl.system.MemoryUtil.NULL import org.lwjgl.vulkan.EXTDebugReport.* import org.lwjgl.vulkan.VK10.VK_SUCCESS import org.lwjgl.vulkan.{VkDebugReportCallbackCreateInfoEXT, VkDebugReportCallbackEXT} -import org.slf4j.{Logger, LoggerFactory} import java.lang.Integer.highestOneBit -import java.nio.LongBuffer /** @author * MarconZet Created 13.04.2020 */ -object DebugCallback { +object DebugCallback: val DEBUG_REPORT = VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT -} -private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandle { - override protected val handle: Long = { - val debugCallback = new VkDebugReportCallbackEXT() { +private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandle: + override protected val handle: Long = + val debugCallback = new VkDebugReportCallbackEXT(): def invoke( flags: Int, objectType: Int, @@ -33,9 +29,9 @@ private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandl pLayerPrefix: Long, pMessage: Long, pUserData: Long, - ): Int = { + ): Int = val decodedMessage = VkDebugReportCallbackEXT.getString(pMessage) - highestOneBit(flags) match { + highestOneBit(flags) match case VK_DEBUG_REPORT_DEBUG_BIT_EXT => logger.debug(decodedMessage) case VK_DEBUG_REPORT_ERROR_BIT_EXT => @@ -45,17 +41,13 @@ private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandl case VK_DEBUG_REPORT_INFORMATION_BIT_EXT => logger.info(decodedMessage) case x => logger.error(s"Unexpected value: x, message: $decodedMessage") - } 0 - } - } setupDebugging(DEBUG_REPORT, debugCallback) - } override protected def close(): Unit = vkDestroyDebugReportCallbackEXT(instance.get, handle, null) - private def setupDebugging(flags: Int, callback: VkDebugReportCallbackEXT): Long = { + private def setupDebugging(flags: Int, callback: VkDebugReportCallbackEXT): Long = val dbgCreateInfo = VkDebugReportCallbackCreateInfoEXT .create() .sType$Default() @@ -68,5 +60,3 @@ private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandl val callbackHandle = pCallback.get(0) if err != VK_SUCCESS then throw new VulkanAssertionError("Failed to create DebugCallback", err) callbackHandle - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala index fe3e519c..e23ea52e 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala @@ -17,33 +17,27 @@ import scala.jdk.CollectionConverters.given * MarconZet Created 13.04.2020 */ -object Device { +object Device: final val MacOsExtension = VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME final val SyncExtension = VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME -} -private[cyfra] class Device(instance: Instance) extends VulkanObject { - - val physicalDevice: VkPhysicalDevice = pushStack { stack => +private[cyfra] class Device(instance: Instance) extends VulkanObject: + val physicalDevice: VkPhysicalDevice = pushStack: stack => val pPhysicalDeviceCount = stack.callocInt(1) check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, null), "Failed to get number of physical devices") val deviceCount = pPhysicalDeviceCount.get(0) if deviceCount == 0 then throw new AssertionError("Failed to find GPUs with Vulkan support") - val pPhysicalDevices = stack.callocPointer(deviceCount) check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, pPhysicalDevices), "Failed to get physical devices") - new VkPhysicalDevice(pPhysicalDevices.get(), instance.get) - } - val physicalDeviceName: String = pushStack { stack => + val physicalDeviceName: String = pushStack: stack => val pProperties = VkPhysicalDeviceProperties.calloc(stack) vkGetPhysicalDeviceProperties(physicalDevice, pProperties) pProperties.deviceNameString() - } - val computeQueueFamily: Int = pushStack { stack => + val computeQueueFamily: Int = pushStack: stack => val pQueueFamilyCount = stack.callocInt(1) vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, null) val queueFamilyCount = pQueueFamilyCount.get(0) @@ -64,87 +58,82 @@ private[cyfra] class Device(instance: Instance) extends VulkanObject { (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0 }) .getOrElse(throw new AssertionError("No suitable queue family found for computing")) - } - - private val device: VkDevice = - pushStack { stack => - val pPropertiesCount = stack.callocInt(1) - check( - vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, null), - "Failed to get number of properties extension", - ) - val propertiesCount = pPropertiesCount.get(0) - - val pProperties = VkExtensionProperties.calloc(propertiesCount, stack) - check( - vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties), - "Failed to get extension properties", - ) - - val deviceExtensions = pProperties.iterator().asScala.map(_.extensionNameString()) - val deviceExtensionsSet = deviceExtensions.toSet - - val vulkan12Features = VkPhysicalDeviceVulkan12Features - .calloc(stack) - .sType$Default() - - val vulkan13Features = VkPhysicalDeviceVulkan13Features - .calloc(stack) - .sType$Default() - - val physicalDeviceFeatures = VkPhysicalDeviceFeatures2 - .calloc(stack) - .sType$Default() - .pNext(vulkan12Features) - .pNext(vulkan13Features) - - vkGetPhysicalDeviceFeatures2(physicalDevice, physicalDeviceFeatures) - - val additionalExtension = pProperties.stream().anyMatch(x => x.extensionNameString().equals(MacOsExtension)) - - val pQueuePriorities = stack.callocFloat(1).put(1.0f) - pQueuePriorities.flip() - - val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack) - pQueueCreateInfo - .get(0) - .sType$Default() - .pNext(0) - .flags(0) - .queueFamilyIndex(computeQueueFamily) - .pQueuePriorities(pQueuePriorities) - - val extensions = Seq(MacOsExtension, SyncExtension).filter(deviceExtensionsSet) - val ppExtensionNames = stack.callocPointer(extensions.length) - extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension))) - ppExtensionNames.flip() - - val sync2 = VkPhysicalDeviceSynchronization2Features - .calloc(stack) - .sType$Default() - .synchronization2(true) - - val pCreateInfo = VkDeviceCreateInfo - .create() - .sType$Default() - .pNext(sync2) - .pQueueCreateInfos(pQueueCreateInfo) - .ppEnabledExtensionNames(ppExtensionNames) - - if instance.enabledLayers.contains(ValidationLayer) then { - val ppValidationLayers = stack.callocPointer(1).put(stack.ASCII(ValidationLayer)) - pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip()) - } - - assert(vulkan13Features.synchronization2() || extensions.contains(SyncExtension)) - val pDevice = stack.callocPointer(1) - check(vkCreateDevice(physicalDevice, pCreateInfo, null, pDevice), "Failed to create device") - new VkDevice(pDevice.get(0), physicalDevice, pCreateInfo) - } + private val device: VkDevice = pushStack: stack => + val pPropertiesCount = stack.callocInt(1) + check( + vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, null), + "Failed to get number of properties extension", + ) + val propertiesCount = pPropertiesCount.get(0) + + val pProperties = VkExtensionProperties.calloc(propertiesCount, stack) + check( + vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties), + "Failed to get extension properties", + ) + + val deviceExtensions = pProperties.iterator().asScala.map(_.extensionNameString()) + val deviceExtensionsSet = deviceExtensions.toSet + + val vulkan12Features = VkPhysicalDeviceVulkan12Features + .calloc(stack) + .sType$Default() + + val vulkan13Features = VkPhysicalDeviceVulkan13Features + .calloc(stack) + .sType$Default() + + val physicalDeviceFeatures = VkPhysicalDeviceFeatures2 + .calloc(stack) + .sType$Default() + .pNext(vulkan12Features) + .pNext(vulkan13Features) + + vkGetPhysicalDeviceFeatures2(physicalDevice, physicalDeviceFeatures) + + val additionalExtension = pProperties.stream().anyMatch(x => x.extensionNameString().equals(MacOsExtension)) + + val pQueuePriorities = stack.callocFloat(1).put(1.0f) + pQueuePriorities.flip() + + val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack) + pQueueCreateInfo + .get(0) + .sType$Default() + .pNext(0) + .flags(0) + .queueFamilyIndex(computeQueueFamily) + .pQueuePriorities(pQueuePriorities) + + val extensions = Seq(MacOsExtension, SyncExtension).filter(deviceExtensionsSet) + val ppExtensionNames = stack.callocPointer(extensions.length) + extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension))) + ppExtensionNames.flip() + + val sync2 = VkPhysicalDeviceSynchronization2Features + .calloc(stack) + .sType$Default() + .synchronization2(true) + + val pCreateInfo = VkDeviceCreateInfo + .create() + .sType$Default() + .pNext(sync2) + .pQueueCreateInfos(pQueueCreateInfo) + .ppEnabledExtensionNames(ppExtensionNames) + + if instance.enabledLayers.contains(ValidationLayer) then + val ppValidationLayers = stack.callocPointer(1).put(stack.ASCII(ValidationLayer)) + pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip()) + + assert(vulkan13Features.synchronization2() || extensions.contains(SyncExtension)) + + val pDevice = stack.callocPointer(1) + check(vkCreateDevice(physicalDevice, pCreateInfo, null, pDevice), "Failed to create device") + new VkDevice(pDevice.get(0), physicalDevice, pCreateInfo) def get: VkDevice = device override protected def close(): Unit = vkDestroyDevice(device, null) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala index 3d1e8da4..4f7479e3 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala @@ -1,7 +1,7 @@ package io.computenode.cyfra.vulkan.core import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.VulkanContext.{SyncLayer, ValidationLayer} +import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayer import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObject import org.lwjgl.system.MemoryStack @@ -10,8 +10,6 @@ import org.lwjgl.vulkan.* import org.lwjgl.vulkan.EXTDebugReport.VK_EXT_DEBUG_REPORT_EXTENSION_NAME import org.lwjgl.vulkan.KHRPortabilityEnumeration.{VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR, VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME} import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VK13.* -import org.slf4j.LoggerFactory import java.nio.ByteBuffer import scala.collection.mutable @@ -21,11 +19,11 @@ import scala.util.chaining.* /** @author * MarconZet Created 13.04.2020 */ -object Instance { +object Instance: val ValidationLayersExtensions: Seq[String] = List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME) val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME) - lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack { stack => + lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack: stack => val ip = stack.ints(1) vkEnumerateInstanceLayerProperties(ip, null) val availableLayers = VkLayerProperties.malloc(ip.get(0), stack) @@ -38,16 +36,12 @@ object Instance { val extensions = instance_extensions.iterator().asScala.map(_.extensionNameString()) val layers = availableLayers.iterator().asScala.map(_.layerNameString()) (extensions.toSeq, layers.toSeq) - } lazy val version: Int = VK.getInstanceVersionSupported -} - -private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObject { - - private val instance: VkInstance = pushStack { stack => +private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObject: + private val instance: VkInstance = pushStack: stack => val appInfo = VkApplicationInfo .calloc(stack) .sType$Default() @@ -59,12 +53,11 @@ private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObj .apiVersion(Instance.version) val ppEnabledExtensionNames = getInstanceExtensions(stack) - val ppEnabledLayerNames = { + val ppEnabledLayerNames = val layers = enabledLayers val pointer = stack.callocPointer(layers.length) layers.foreach(x => pointer.put(stack.ASCII(x))) pointer.flip() - } val pCreateInfo = VkInstanceCreateInfo .calloc(stack) @@ -77,7 +70,6 @@ private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObj val pInstance = stack.mallocPointer(1) check(vkCreateInstance(pCreateInfo, null, pInstance), "Failed to create VkInstance") new VkInstance(pInstance.get(0), pCreateInfo) - } lazy val enabledLayers: Seq[String] = List .empty[String] @@ -94,19 +86,18 @@ private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObj override protected def close(): Unit = vkDestroyInstance(instance, null) - private def getInstanceExtensions(stack: MemoryStack) = { + private def getInstanceExtensions(stack: MemoryStack) = val n = stack.callocInt(1) check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, null)) val buffer = VkExtensionProperties.calloc(n.get(0), stack) check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, buffer)) - val availableExtensions = { + val availableExtensions = val buf = mutable.Buffer[String]() buffer.forEach { ext => buf.addOne(ext.extensionNameString()) } buf.toSet - } val extensions = mutable.Buffer.from(Instance.MoltenVkExtensions) if enableValidationLayers then extensions.addAll(Instance.ValidationLayersExtensions) @@ -120,5 +111,3 @@ private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObj val ppEnabledExtensionNames = stack.callocPointer(extensions.size) filteredExtensions.foreach(x => ppEnabledExtensionNames.put(stack.ASCII(x))) ppEnabledExtensionNames.flip() - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala index a55e4a36..823f129e 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala @@ -12,7 +12,7 @@ import org.lwjgl.vulkan.VK10.* import java.nio.ByteBuffer -private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferActions: Seq[BufferAction], context: VulkanContext) { +private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferActions: Seq[BufferAction], context: VulkanContext): protected val device: Device = context.device protected val queue: Queue = context.computeQueue protected val allocator: Allocator = context.allocator @@ -20,24 +20,21 @@ private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferAction protected val commandPool: CommandPool = context.commandPool protected val (descriptorSets, buffers) = setupBuffers() - private val commandBuffer: VkCommandBuffer = - pushStack { stack => - val commandBuffer = commandPool.createCommandBuffer() + private val commandBuffer: VkCommandBuffer = pushStack: stack => + val commandBuffer = commandPool.createCommandBuffer() + val commandBufferBeginInfo = VkCommandBufferBeginInfo + .calloc(stack) + .sType$Default() + .flags(0) - val commandBufferBeginInfo = VkCommandBufferBeginInfo - .calloc(stack) - .sType$Default() - .flags(0) - - check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") + check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") - recordCommandBuffer(commandBuffer) + recordCommandBuffer(commandBuffer) - check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") - commandBuffer - } + check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") + commandBuffer - def execute(input: Seq[ByteBuffer]): Seq[ByteBuffer] = { + def execute(input: Seq[ByteBuffer]): Seq[ByteBuffer] = val stagingBuffer = new Buffer( getBiggestTransportData * dataLength, @@ -46,13 +43,12 @@ private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferAction VMA_MEMORY_USAGE_UNKNOWN, allocator, ) - for i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadTo do { + for i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadTo do val buffer = input(i) Buffer.copyBuffer(buffer, stagingBuffer, buffer.remaining()) Buffer.copyBuffer(stagingBuffer, buffers(i), buffer.remaining(), commandPool).block().destroy() - } - pushStack { stack => + pushStack: stack => val fence = new Fence(device) val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer) val submitInfo = VkSubmitInfo @@ -62,29 +58,24 @@ private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferAction check(VK10.vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue") fence.block().destroy() - } - val output = for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadFrom) yield { + val output = for i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadFrom yield val fence = Buffer.copyBuffer(buffers(i), stagingBuffer, buffers(i).size, commandPool) val outBuffer = BufferUtils.createByteBuffer(buffers(i).size) fence.block().destroy() Buffer.copyBuffer(stagingBuffer, outBuffer, outBuffer.remaining()) outBuffer - } stagingBuffer.destroy() output - } - def destroy(): Unit = { + def destroy(): Unit = commandPool.freeCommandBuffer(commandBuffer) descriptorSets.foreach(_.destroy()) buffers.foreach(_.destroy()) - } protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit protected def getBiggestTransportData: Int -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala index 0d88ff41..e287a04b 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala @@ -1,24 +1,20 @@ package io.computenode.cyfra.vulkan.executor -import io.computenode.cyfra.vulkan.compute.* import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, Shader, UniformSize} +import io.computenode.cyfra.vulkan.compute.* import io.computenode.cyfra.vulkan.memory.{Buffer, DescriptorSet} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.pushStack import org.lwjgl.util.vma.Vma.* import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* import scala.collection.mutable -import scala.util.Using /** @author * MarconZet Created 15.04.2020 */ private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferAction], computePipeline: ComputePipeline, context: VulkanContext) - extends AbstractExecutor(dataLength, bufferActions, context) { + extends AbstractExecutor(dataLength, bufferActions, context): private lazy val shader: Shader = computePipeline.computeShader protected def getBiggestTransportData: Int = shader.layoutInfo.sets @@ -28,30 +24,27 @@ private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferActio } .max - protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) = pushStack { stack => + protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) = pushStack: stack => val bindings = shader.layoutInfo.sets.flatMap(_.bindings) val buffers = bindings.zipWithIndex.map { case (binding, i) => - val bufferSize = binding.size match { + val bufferSize = binding.size match case InputBufferSize(n) => n * dataLength case UniformSize(n) => n - } new Buffer(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | bufferActions(i).action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator) } val bufferDeque = mutable.ArrayDeque.from(buffers) val descriptorSetLayouts = computePipeline.descriptorSetLayouts - val descriptorSets = for (i <- descriptorSetLayouts.indices) yield { + val descriptorSets = for i <- descriptorSetLayouts.indices yield val descriptorSet = new DescriptorSet(device, descriptorSetLayouts(i)._1, descriptorSetLayouts(i)._2.bindings, descriptorPool) val size = descriptorSetLayouts(i)._2.bindings.size descriptorSet.update(bufferDeque.take(size).toSeq) bufferDeque.drop(size) descriptorSet - } (descriptorSets, buffers) - } protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit = - pushStack { stack => + pushStack: stack => vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get) val pDescriptorSets = stack.longs(descriptorSets.map(_.get)*) @@ -59,6 +52,3 @@ private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferActio val workgroup = shader.workgroupDimensions vkCmdDispatch(commandBuffer, dataLength / workgroup.x(), 1 / workgroup.y(), 1 / workgroup.z()) - } - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala index 86126ebd..1edacb32 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala @@ -1,17 +1,13 @@ package io.computenode.cyfra.vulkan.executor +import io.computenode.cyfra.utility.Utility.timed +import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.command.* import io.computenode.cyfra.vulkan.compute.* import io.computenode.cyfra.vulkan.core.* -import SequenceExecutor.* -import io.computenode.cyfra.utility.Utility.timed +import io.computenode.cyfra.vulkan.executor.SequenceExecutor.* import io.computenode.cyfra.vulkan.memory.* -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Queue} -import io.computenode.cyfra.vulkan.compute.{ComputePipeline, InputBufferSize, LayoutSet, UniformSize} import io.computenode.cyfra.vulkan.util.Util.* -import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer, DescriptorPool, DescriptorSet} import org.lwjgl.BufferUtils import org.lwjgl.util.vma.Vma.* import org.lwjgl.vulkan.* @@ -24,15 +20,16 @@ import java.nio.ByteBuffer /** @author * MarconZet Created 15.04.2020 */ -private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, context: VulkanContext) { +private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, context: VulkanContext): private val device: Device = context.device private val queue: Queue = context.computeQueue private val allocator: Allocator = context.allocator private val descriptorPool: DescriptorPool = context.descriptorPool private val commandPool: CommandPool = context.commandPool - private val pipelineToDescriptorSets: Map[ComputePipeline, Seq[DescriptorSet]] = pushStack { stack => - val pipelines = computeSequence.sequence.collect { case Compute(pipeline, _) => pipeline } + private val pipelineToDescriptorSets: Map[ComputePipeline, Seq[DescriptorSet]] = pushStack: stack => + val pipelines = computeSequence.sequence.collect: + case Compute(pipeline, _) => pipeline val rawSets = pipelines.map(_.computeShader.layoutInfo.sets) val numbered = rawSets.flatten.zipWithIndex @@ -71,11 +68,10 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont .toMap pipelines.zip(resolvedSets.map(_.map(descriptorSetMap(_)))).toMap - } private val descriptorSets = pipelineToDescriptorSets.toSeq.flatMap(_._2).distinctBy(_.get) - private def recordCommandBuffer(dataLength: Int): VkCommandBuffer = pushStack { stack => + private def recordCommandBuffer(dataLength: Int): VkCommandBuffer = pushStack: stack => val pipelinesHasDependencies = computeSequence.dependencies.map(_.to).toSet val commandBuffer = commandPool.createCommandBuffer() @@ -114,9 +110,8 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") commandBuffer - } - private def createBuffers(dataLength: Int): Map[DescriptorSet, Seq[Buffer]] = { + private def createBuffers(dataLength: Int): Map[DescriptorSet, Seq[Buffer]] = val setToActions = computeSequence.sequence .collect { case Compute(pipeline, bufferActions) => @@ -147,17 +142,19 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont .toMap setToBuffers - } - def execute(inputs: Seq[ByteBuffer], dataLength: Int): Seq[ByteBuffer] = pushStack { stack => + def execute(inputs: Seq[ByteBuffer], dataLength: Int): Seq[ByteBuffer] = pushStack: stack => timed("Vulkan full execute"): val setToBuffers = createBuffers(dataLength) def buffersWithAction(bufferAction: BufferAction): Seq[Buffer] = computeSequence.sequence.collect { case x: Compute => - pipelineToDescriptorSets(x.pipeline).map(setToBuffers).zip(x.pumpLayoutLocations).flatMap(x => x._1.zip(x._2)).collect { - case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer - } + pipelineToDescriptorSets(x.pipeline) + .map(setToBuffers) + .zip(x.pumpLayoutLocations) + .flatMap(x => x._1.zip(x._2)) + .collect: + case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer }.flatten val stagingBuffer = @@ -199,14 +196,11 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont setToBuffers.flatMap(_._2).foreach(_.destroy()) output - } def destroy(): Unit = descriptorSets.foreach(_.destroy()) -} - -object SequenceExecutor { +object SequenceExecutor: private[cyfra] case class ComputationSequence(sequence: Seq[ComputationStep], dependencies: Seq[Dependency]) private[cyfra] sealed trait ComputationStep @@ -218,5 +212,3 @@ object SequenceExecutor { case class LayoutLocation(set: Int, binding: Int) case class Dependency(from: ComputePipeline, fromSet: Int, to: ComputePipeline, toSet: Int) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala index 20b0b85e..147e1eda 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala @@ -3,16 +3,15 @@ package io.computenode.cyfra.vulkan.memory import io.computenode.cyfra.vulkan.core.{Device, Instance} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle -import org.lwjgl.system.MemoryStack import org.lwjgl.util.vma.Vma.{vmaCreateAllocator, vmaDestroyAllocator} import org.lwjgl.util.vma.{VmaAllocatorCreateInfo, VmaVulkanFunctions} /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Allocator(instance: Instance, device: Device) extends VulkanObjectHandle { +private[cyfra] class Allocator(instance: Instance, device: Device) extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val functions = VmaVulkanFunctions.calloc(stack) functions.set(instance.get, device.get) val allocatorInfo = VmaAllocatorCreateInfo @@ -25,8 +24,6 @@ private[cyfra] class Allocator(instance: Instance, device: Device) extends Vulka val pAllocator = stack.callocPointer(1) check(vmaCreateAllocator(allocatorInfo, pAllocator), "Failed to create allocator") pAllocator.get(0) - } def close(): Unit = vmaDestroyAllocator(handle) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala index 91c27ec1..53b364a7 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala @@ -1,26 +1,22 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.command.{CommandPool, Fence} -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.PointerBuffer -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.system.MemoryUtil.* import org.lwjgl.util.vma.Vma.* import org.lwjgl.util.vma.VmaAllocationCreateInfo import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo, VkCommandBuffer} +import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo} -import java.nio.{ByteBuffer, LongBuffer} -import scala.util.Using +import java.nio.ByteBuffer /** @author * MarconZet Created 11.05.2019 */ -private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: Int, val allocator: Allocator) extends VulkanObjectHandle { +private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: Int, val allocator: Allocator) extends VulkanObjectHandle: - val (handle, allocation) = pushStack { stack => + val (handle, allocation) = pushStack: stack => val bufferInfo = VkBufferCreateInfo .calloc(stack) .sType$Default() @@ -39,42 +35,37 @@ private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: val pAllocation = stack.callocPointer(1) check(vmaCreateBuffer(allocator.get, bufferInfo, allocInfo, pBuffer, pAllocation, null), "Failed to create buffer") (pBuffer.get(), pAllocation.get()) - } - def get(dst: Array[Byte]): Unit = { + def get(dst: Array[Byte]): Unit = val len = Math.min(dst.length, size) val byteBuffer = memCalloc(len) Buffer.copyBuffer(this, byteBuffer, len) byteBuffer.get(dst) memFree(byteBuffer) - } protected def close(): Unit = vmaDestroyBuffer(allocator.get, handle, allocation) -} -object Buffer { +object Buffer: def copyBuffer(src: ByteBuffer, dst: Buffer, bytes: Long): Unit = - pushStack { stack => + pushStack: stack => val pData = stack.callocPointer(1) check(vmaMapMemory(dst.allocator.get, dst.allocation, pData), "Failed to map destination buffer memory") val data = pData.get() memCopy(memAddress(src), data, bytes) vmaFlushAllocation(dst.allocator.get, dst.allocation, 0, bytes) vmaUnmapMemory(dst.allocator.get, dst.allocation) - } def copyBuffer(src: Buffer, dst: ByteBuffer, bytes: Long): Unit = - pushStack { stack => + pushStack: stack => val pData = stack.callocPointer(1) check(vmaMapMemory(src.allocator.get, src.allocation, pData), "Failed to map destination buffer memory") val data = pData.get() memCopy(data, memAddress(dst), bytes) vmaUnmapMemory(src.allocator.get, src.allocation) - } def copyBuffer(src: Buffer, dst: Buffer, bytes: Long, commandPool: CommandPool): Fence = - pushStack { stack => + pushStack: stack => val commandBuffer = commandPool.beginSingleTimeCommands() val copyRegion = VkBufferCopy @@ -85,6 +76,3 @@ object Buffer { vkCmdCopyBuffer(commandBuffer, src.get, dst.get, copyRegion) commandPool.endSingleTimeCommands(commandBuffer) - } - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala index b8c83398..f6ceced3 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala @@ -1,25 +1,19 @@ package io.computenode.cyfra.vulkan.memory -import DescriptorPool.MAX_SETS -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.memory.DescriptorPool.MAX_SETS +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.{VkDescriptorPoolCreateInfo, VkDescriptorPoolSize} -import java.nio.LongBuffer -import scala.util.Using - /** @author * MarconZet Created 14.04.2019 */ -object DescriptorPool { +object DescriptorPool: val MAX_SETS = 100 -} -private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val descriptorPoolSize = VkDescriptorPoolSize.calloc(1, stack) descriptorPoolSize .get(0) @@ -36,8 +30,6 @@ private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle { val pDescriptorPool = stack.callocLong(1) check(vkCreateDescriptorPool(device.get, descriptorPoolCreateInfo, null, pDescriptorPool), "Failed to create descriptor pool") pDescriptorPool.get() - } override protected def close(): Unit = vkDestroyDescriptorPool(device.get, handle, null) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala index ef91eed4..e9f49d4d 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala @@ -1,10 +1,9 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.compute.{Binding, LayoutSet} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.compute.Binding import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle -import org.lwjgl.system.MemoryStack import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.{VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, VkWriteDescriptorSet} @@ -12,9 +11,9 @@ import org.lwjgl.vulkan.{VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, Vk * MarconZet Created 15.04.2020 */ private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, val bindings: Seq[Binding], descriptorPool: DescriptorPool) - extends VulkanObjectHandle { + extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val pSetLayout = stack.callocLong(1).put(0, descriptorSetLayout) val descriptorSetAllocateInfo = VkDescriptorSetAllocateInfo .calloc(stack) @@ -25,9 +24,8 @@ private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, va val pDescriptorSet = stack.callocLong(1) check(vkAllocateDescriptorSets(device.get, descriptorSetAllocateInfo, pDescriptorSet), "Failed to allocate descriptor set") pDescriptorSet.get() - } - def update(buffers: Seq[Buffer]): Unit = pushStack { stack => + def update(buffers: Seq[Buffer]): Unit = pushStack: stack => val writeDescriptorSet = VkWriteDescriptorSet.calloc(buffers.length, stack) buffers.indices foreach { i => val descriptorBufferInfo = VkDescriptorBufferInfo @@ -48,8 +46,6 @@ private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, va .pBufferInfo(descriptorBufferInfo) } vkUpdateDescriptorSets(device.get, writeDescriptorSet, null) - } override protected def close(): Unit = vkFreeDescriptorSets(device.get, descriptorPool.get, handle) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala index 9065f490..fcdb71aa 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala @@ -5,7 +5,6 @@ import org.lwjgl.vulkan.VK10.VK_SUCCESS import scala.util.Using -object Util { +object Util: def pushStack[T](f: MemoryStack => T): T = Using(MemoryStack.stackPush())(f).get def check(err: Int, message: String = ""): Unit = if err != VK_SUCCESS then throw new VulkanAssertionError(message, err) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala index 3326bf0d..df8a75a0 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala @@ -12,7 +12,7 @@ import org.lwjgl.vulkan.VK10.* private[cyfra] class VulkanAssertionError(msg: String, result: Int) extends AssertionError(s"$msg: ${VulkanAssertionError.translateVulkanResult(result)}") -object VulkanAssertionError { +object VulkanAssertionError: def translateVulkanResult(result: Int): String = result match // Success codes @@ -69,4 +69,3 @@ object VulkanAssertionError { "A validation layer found an error." case x => s"Unknown $x" -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala index efecc480..b896706b 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala @@ -3,15 +3,12 @@ package io.computenode.cyfra.vulkan.util /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] abstract class VulkanObject { +private[cyfra] abstract class VulkanObject: protected var alive: Boolean = true - def destroy(): Unit = { + def destroy(): Unit = if !alive then throw new IllegalStateException() close() alive = false - } protected def close(): Unit - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala index 9ea98538..acc448c7 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala @@ -3,10 +3,9 @@ package io.computenode.cyfra.vulkan.util /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] abstract class VulkanObjectHandle extends VulkanObject { +private[cyfra] abstract class VulkanObjectHandle extends VulkanObject: protected val handle: Long def get: Long = if !alive then throw new IllegalStateException() else handle -}