Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ lazy val commonSettings = Seq(
"org.lwjgl" % "lwjgl-vma" % lwjglVersion classifier lwjglNatives,
"org.joml" % "joml" % jomlVersion,
"commons-io" % "commons-io" % "2.16.1",
"org.slf4j" % "slf4j-api" % "1.7.30",
"org.slf4j" % "slf4j-simple" % "1.7.30" % Test,
"org.scalameta" % "munit_3" % "1.0.0" % Test,
"com.lihaoyi" %% "sourcecode" % "0.4.3-M5",
"org.slf4j" % "slf4j-api" % "2.0.17",
"org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3" % Test,
) ++ vulkanNatives,
)

Expand All @@ -60,6 +59,10 @@ lazy val runnerSettings = Seq(libraryDependencies += "org.apache.logging.log4j"
lazy val utility = (project in file("cyfra-utility"))
.settings(commonSettings)

lazy val spirvTools = (project in file("cyfra-spirv-tools"))
.settings(commonSettings)
.dependsOn(utility)

lazy val vulkan = (project in file("cyfra-vulkan"))
.settings(commonSettings)
.dependsOn(utility)
Expand All @@ -74,7 +77,7 @@ lazy val compiler = (project in file("cyfra-compiler"))

lazy val runtime = (project in file("cyfra-runtime"))
.settings(commonSettings)
.dependsOn(compiler, dsl, vulkan, utility)
.dependsOn(compiler, dsl, vulkan, utility, spirvTools)

lazy val foton = (project in file("cyfra-foton"))
.settings(commonSettings)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,23 @@ import io.computenode.cyfra.*
import io.computenode.cyfra.dsl.collections.GSeq
import io.computenode.cyfra.dsl.control.Pure.pure
import io.computenode.cyfra.dsl.struct.GStruct.Empty
import io.computenode.cyfra.runtime.{GContext, GFunction}
import org.apache.commons.io.IOUtils
import org.junit.runner.RunWith
import io.computenode.cyfra.e2e.ImageTests
import io.computenode.cyfra.runtime.mem.Vec4FloatMem
import io.computenode.cyfra.runtime.{GContext, GFunction}
import io.computenode.cyfra.spirvtools.*
import io.computenode.cyfra.spirvtools.SpirvTool.{Param, ToFile}
import io.computenode.cyfra.utility.ImageUtility
import munit.FunSuite

import java.io.File
import java.nio.file.Files
import java.nio.file.Paths
import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits
import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext}
import io.computenode.cyfra.e2e.ImageTests

class JuliaSet extends FunSuite:
given GContext = new GContext()
given ExecutionContext = Implicits.global

test("Render julia set"):
def runJuliaSet(referenceImgName: String)(using GContext): Unit = {
val dim = 4096
val max = 1
val RECURSION_LIMIT = 1000
Expand Down Expand Up @@ -70,5 +68,22 @@ class JuliaSet extends FunSuite:
val r = Vec4FloatMem(dim * dim).map(function).asInstanceOf[Vec4FloatMem].toArray
val outputTemp = File.createTempFile("julia", ".png")
ImageUtility.renderToImage(r, dim, outputTemp.toPath)
val referenceImage = getClass.getResource("julia.png")
val referenceImage = getClass.getResource(referenceImgName)
ImageTests.assertImagesEquals(outputTemp, new File(referenceImage.getPath))
}

test("Render julia set"):
given GContext = new GContext
runJuliaSet("/julia.png")

test("Render julia set optimized"):
given GContext = new GContext(
SpirvToolsRunner(
validator = SpirvValidator.Enable(throwOnFail = true),
optimizer = SpirvOptimizer.Enable(toolOutput = ToFile(Paths.get("output/optimized.spv")), settings = Seq(Param("-O"))),
disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/optimized.spvasm")), throwOnFail = true),
crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl")), throwOnFail = true),
originalSpirvOutput = ToFile(Paths.get("output/original.spv")),
),
)
runJuliaSet("/julia_O_optimized.png")
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
package io.computenode.cyfra.runtime

import io.computenode.cyfra.dsl.{*, given}
import Value.{Float32, Int32, Vec4}
import io.computenode.cyfra.vulkan.VulkanContext
import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader, UniformSize}
import io.computenode.cyfra.vulkan.executor.{BufferAction, SequenceExecutor}
import SequenceExecutor.*
import io.computenode.cyfra.dsl.Value
import io.computenode.cyfra.dsl.Value.{Float32, FromExpr, Int32, Vec4}
import io.computenode.cyfra.dsl.collections.GArray
import io.computenode.cyfra.dsl.struct.*
import io.computenode.cyfra.dsl.struct.GStruct.*
import io.computenode.cyfra.runtime.mem.GMem.totalStride
import io.computenode.cyfra.runtime.mem.{FloatMem, GMem, IntMem, Vec4FloatMem}
import io.computenode.cyfra.spirv.SpirvTypes.typeStride
import io.computenode.cyfra.spirv.compilers.DSLCompiler
import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.{UniformStructRef, WorkerIndex}
import mem.{FloatMem, GMem, IntMem, Vec4FloatMem}
import org.lwjgl.system.{Configuration, MemoryUtil}
import io.computenode.cyfra.spirvtools.SpirvToolsRunner
import io.computenode.cyfra.vulkan.VulkanContext
import io.computenode.cyfra.vulkan.compute.*
import io.computenode.cyfra.vulkan.executor.SequenceExecutor.*
import io.computenode.cyfra.vulkan.executor.{BufferAction, SequenceExecutor}
import izumi.reflect.Tag
import org.lwjgl.system.Configuration

import java.io.FileOutputStream
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}

class GContext:

class GContext(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()):
Configuration.STACK_SIZE.set(1024) // fix lwjgl stack size

val vkContext = new VulkanContext()
Expand All @@ -38,20 +34,17 @@ class GContext:
val uniformStruct = uniformStructSchema.fromTree(UniformStructRef)
val tree = function.fn
.apply(uniformStruct, WorkerIndex, GArray[H](0))
val shaderCode = DSLCompiler.compile(tree, function.arrayInputs, function.arrayOutputs, uniformStructSchema)
dumpSpvToFile(shaderCode, "program.spv") // TODO remove before release

val optimizedShaderCode =
spirvToolsRunner.processShaderCodeWithSpirvTools(DSLCompiler.compile(tree, function.arrayInputs, function.arrayOutputs, uniformStructSchema))

val inOut = 0 to 1 map (Binding(_, InputBufferSize(typeStride(summon[Tag[H]]))))
val uniform = Option.when(uniformStructSchema.fields.nonEmpty)(Binding(2, UniformSize(totalStride(uniformStructSchema))))
val layoutInfo = LayoutInfo(Seq(LayoutSet(0, inOut ++ uniform)))
val shader = new Shader(shaderCode, new org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device)
new ComputePipeline(shader, vkContext)
}

private def dumpSpvToFile(code: ByteBuffer, path: String): Unit =
val fc: FileChannel = new FileOutputStream("program.spv").getChannel
fc.write(code)
fc.close()
code.rewind()
val shader = Shader(optimizedShaderCode, org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device)
ComputePipeline(shader, vkContext)
}

def execute[G <: GStruct[G]: Tag: GStructSchema, H <: Value, R <: Value](mem: GMem[H], fn: GFunction[G, H, R])(using
uniformContext: UniformContext[G],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.computenode.cyfra.spirvtools

import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd
import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger}
import io.computenode.cyfra.utility.Logger.logger

import java.nio.ByteBuffer

object SpirvCross extends SpirvTool("spirv-cross"):

def crossCompileSpirv(shaderCode: ByteBuffer, crossCompilation: CrossCompilation): Option[String] =
crossCompilation match
case Enable(throwOnFail, toolOutput, params) =>
val crossCompilationRes = tryCrossCompileSpirv(shaderCode, params)
crossCompilationRes match
case Left(err) if throwOnFail => throw err
case Left(err) =>
logger.warn(err.message)
None
case Right(crossCompiledCode) =>
toolOutput match
case Ignore =>
case toFile @ SpirvTool.ToFile(_) =>
toFile.write(crossCompiledCode)
logger.debug(s"Saved cross compiled shader code in ${toFile.filePath}.")
case ToLogger => logger.debug(s"SPIR-V Cross Compilation result:\n$crossCompiledCode")
Some(crossCompiledCode)
case Disable =>
logger.debug("SPIR-V cross compilation is disabled.")
None

private def tryCrossCompileSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] =
val cmd = Seq(toolName) ++ Seq("-") ++ params.flatMap(_.asStringParam.split(" "))
for
(stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
result <- Either.cond(
exitCode == 0, {
logger.debug("SPIR-V cross compilation succeeded.")
stdout.toString
},
SpirvToolCrossCompilationFailed(exitCode, stderr.toString),
)
yield result

sealed trait CrossCompilation

case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty)
extends CrossCompilation

final case class SpirvToolCrossCompilationFailed(exitCode: Int, stderr: String) extends SpirvToolError:
def message: String =
s"""SPIR-V cross compilation failed with exit code $exitCode.
|Cross errors:
|$stderr""".stripMargin

case object Disable extends CrossCompilation
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package io.computenode.cyfra.spirvtools

import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger}
import io.computenode.cyfra.utility.Logger.logger

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.nio.file.Files

object SpirvDisassembler extends SpirvTool("spirv-dis"):

def disassembleSpirv(shaderCode: ByteBuffer, disassembly: Disassembly): Option[String] =
disassembly match
case Enable(throwOnFail, toolOutput, params) =>
val disassemblyResult = tryGetDisassembleSpirv(shaderCode, params)
disassemblyResult match
case Left(err) if throwOnFail => throw err
case Left(err) =>
logger.warn(err.message)
None
case Right(disassembledShader) =>
toolOutput match
case Ignore =>
case toFile @ SpirvTool.ToFile(_) =>
toFile.write(disassembledShader)
logger.debug(s"Saved disassembled shader code in ${toFile.filePath}.")
case ToLogger => logger.debug(s"SPIR-V Assembly:\n$disassembledShader")
Some(disassembledShader)
case Disable =>
logger.debug("SPIR-V disassembly is disabled.")
None

private def tryGetDisassembleSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] =
val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-")
for
(stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
result <- Either.cond(
exitCode == 0, {
logger.debug("SPIR-V disassembly succeeded.")
stdout.toString
},
SpirvToolDisassemblyFailed(exitCode, stderr.toString),
)
yield result

sealed trait Disassembly

final case class SpirvToolDisassemblyFailed(exitCode: Int, stderr: String) extends SpirvToolError:
def message: String =
s"""SPIR-V disassembly failed with exit code $exitCode.
|Disassembly errors:
|$stderr""".stripMargin

case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty)
extends Disassembly

case object Disable extends Disassembly
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package io.computenode.cyfra.spirvtools

import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd
import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile}
import io.computenode.cyfra.utility.Logger.logger

import java.nio.ByteBuffer

object SpirvOptimizer extends SpirvTool("spirv-opt"):

def optimizeSpirv(shaderCode: ByteBuffer, optimization: Optimization): Option[ByteBuffer] =
optimization match
case Enable(throwOnFail, toolOutput, params) =>
val optimizationRes = tryGetOptimizeSpirv(shaderCode, params)
optimizationRes match
case Left(err) if throwOnFail => throw err
case Left(err) =>
logger.warn(err.message)
None
case Right(optimizedShaderCode) =>
toolOutput match
case SpirvTool.Ignore =>
case toFile @ SpirvTool.ToFile(_) =>
toFile.write(optimizedShaderCode)
logger.debug(s"Saved optimized shader code in ${toFile.filePath}.")
Some(optimizedShaderCode)
case Disable =>
logger.debug("SPIR-V optimization is disabled.")
None

private def tryGetOptimizeSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, ByteBuffer] =
val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-", "-o", "-")
for
(stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
result <- Either.cond(
exitCode == 0, {
logger.debug("SPIR-V optimization succeeded.")
val optimized = toDirectBuffer(ByteBuffer.wrap(stdout.toByteArray))
optimized
},
SpirvToolOptimizationFailed(exitCode, stderr.toString),
)
yield result

private def toDirectBuffer(buf: ByteBuffer): ByteBuffer =
val direct = ByteBuffer.allocateDirect(buf.remaining())
direct.put(buf)
direct.flip()
direct

sealed trait Optimization

case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type = Ignore, settings: Seq[Param] = Seq.empty) extends Optimization

final case class SpirvToolOptimizationFailed(exitCode: Int, stderr: String) extends SpirvToolError:
def message: String =
s"""SPIR-V optimization failed with exit code $exitCode.
|Optimizer errors:
|$stderr""".stripMargin

case object Disable extends Optimization
Loading