diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c47d94a8..64d5109b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,6 +3,7 @@ on: push: branches: - main + - dev tags: - "v*" pull_request: diff --git a/build.sbt b/build.sbt index 44acec5f..836e50e3 100644 --- a/build.sbt +++ b/build.sbt @@ -96,13 +96,17 @@ lazy val vscode = (project in file("cyfra-vscode")) .settings(commonSettings) .dependsOn(foton) +lazy val interpreter = (project in file("cyfra-interpreter")) + .settings(commonSettings) + .dependsOn(dsl, compiler) + lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(commonSettings, runnerSettings) - .dependsOn(runtime) + .dependsOn(runtime, interpreter) lazy val root = (project in file(".")) .settings(name := "Cyfra") - .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples) + .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, interpreter) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala new file mode 100644 index 00000000..27483c16 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* +import Value.FromExpr.fromExpr, control.Scope + +class InterpreterE2eTest extends munit.FunSuite: + test("interpret should not stack overflow"): + val pure = Pure(0) + var gio = FlatMap(pure, pure) + for _ <- 0 until 1000000 do gio = FlatMap(pure, gio) + val result = Interpreter.interpret(gio, 0) + val res = 0 + val exp = 0 + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala new file mode 100644 index 00000000..c561c59c --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -0,0 +1,85 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given}, binding.{ReadBuffer, GBuffer} +import Value.FromExpr.fromExpr, control.Scope +import izumi.reflect.Tag + +class SimulateE2eTest extends munit.FunSuite: + test("simulate binary operation arithmetic"): + val a: Int32 = 1 + val b: Int32 = 2 + val c: Int32 = 3 + val d: Int32 = 4 + val e: Int32 = 5 + val f: Int32 = 6 + val e1 = Diff(a, b) + val e2 = Sum(fromExpr(e1), c) + val e3 = Mul(f, fromExpr(e2)) + val e4 = Div(fromExpr(e3), d) + val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4) + val (result, _) = Simulate.sim(expr, SimContext()) + val expected = 2 + assert(result == expected, s"Expected $expected, got $result") + + test("simulate vec4, scalar, dot, extract scalar"): + val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) + val (res1, _) = Simulate.sim(v1, SimContext()) + val exp1 = Vector(1f, 2f, 3f, 4f) + assert(res1 == exp1, s"Expected $exp1, got $res1") + + val i: Int32 = 2 + val expr = ExtractScalar(fromExpr(v1), i) + val (res, _) = Simulate.sim(expr, SimContext()) + val exp = 3f + assert(res == exp, s"Expected $exp, got $res") + + val v2 = ScalarProd(fromExpr(v1), -1f) + val (res2, _) = Simulate.sim(v2, SimContext()) + val exp2 = Vector(-1f, -2f, -3f, -4f) + assert(res2 == exp2, s"Expected $exp2, got $res2") + + val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f) + val dot = DotProd(fromExpr(v1), fromExpr(v3)) + val (res3a, _) = Simulate.sim(dot, SimContext()) + val res3 = res3a.asInstanceOf[Float] + val exp3 = 0f + assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3") + + test("simulate bitwise ops"): + val a: Int32 = 5 + val by: UInt32 = 3 + val aNot = BitwiseNot(a) + val left = ShiftLeft(fromExpr(aNot), by) + val right = ShiftRight(fromExpr(aNot), by) + val and = BitwiseAnd(fromExpr(left), fromExpr(right)) + val or = BitwiseOr(fromExpr(left), fromExpr(right)) + val xor = BitwiseXor(fromExpr(and), fromExpr(or)) + + val (res, _) = Simulate.sim(xor, SimContext()) + val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) + assert(res == exp, s"Expected $exp, got $res") + + test("simulate should not stack overflow"): + val a: Int32 = 1 + var sum = Sum(a, a) // 2 + for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) + val (res, _) = Simulate.sim(sum, SimContext()) + val exp = 1000002 + assert(res == exp, s"Expected $exp, got $res") + + test("simulate ReadBuffer"): + // We fake a GBuffer with an array + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (0 until 1024).toArray[Result] + + val sc = SimContext().addBuffer(buffer, array) + + val expr = ReadBuffer(buffer, 128) + val (res, newSc) = Simulate.sim(expr, sc) + val exp = 128 + assert(res == exp, s"Expected $exp, got $res") + + // the context should keep track of the read + assert(newSc.reads.contains(ReadBuf(buffer, 128)), "missing read") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala new file mode 100644 index 00000000..0026ebf5 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -0,0 +1,93 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import Value.FromExpr.fromExpr, control.Scope, binding.{GBuffer, ReadBuffer} +import izumi.reflect.Tag + +class SimulateWhenE2eTest extends munit.FunSuite: + test("simulate when"): + val expr1 = WhenExpr( + when = 2 >= 1, // true + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val (res1, _) = Simulate.sim(expr1, SimContext()) + val exp1 = 1 + assert(res1 == exp1, s"Expected $exp1, got $res1") + + test("simulate elseWhen first"): + val expr2 = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 >= 2)) /*true*/, Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val (res2, _) = Simulate.sim(expr2, SimContext()) + val exp2 = 2 + assert(res2 == exp2, s"Expected $exp2, got $res2") + + test("simulate elseWhen second"): + val expr3 = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 <= 3))), // true + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val (res3, _) = Simulate.sim(expr3, SimContext()) + val exp3 = 4 + assert(res3 == exp3, s"Expected $exp3, got $res3") + + test("simulate otherwise"): + val expr4 = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 >= 3))), // false + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val (res4, _) = Simulate.sim(expr4, SimContext()) + val exp4 = 3 + assert(res4 == exp4, s"Expected $exp4, got $res4") + + test("simulate mixed arithmetic, reads and when"): + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (128 until 0 by -1).toArray[Result] + + val sc = SimContext().addBuffer(buffer, array) + + val a: Int32 = 32 + val b: Int32 = 64 + val c: Int32 = 4 + + val readExpr1 = ReadBuffer(buffer, a) // 96 + val expr1 = Mul(c, fromExpr(readExpr1)) // 4 * 96 = 384 + + val readExpr2 = ReadBuffer(buffer, b) // 64 + val expr2 = Sum(c, fromExpr(readExpr2)) // 4 + 64 = 68 + + val expr3 = Mod(fromExpr(expr2), 5) // 68 % 5 = 3 + + val cond1 = fromExpr(expr1) <= fromExpr(expr2) // 384 <= 68 false + val cond2 = Equal(fromExpr(expr1), fromExpr(expr2)) // 384 == 68 false + val cond3 = GreaterThanEqual(fromExpr(expr3), fromExpr(expr2)) // 3 >= 68 false + + val expr = WhenExpr( + when = cond1, // false + thenCode = Scope(expr1), + otherConds = List(Scope(cond2), Scope(cond3)), // false false + otherCaseCodes = List(Scope(expr1), Scope(expr2)), // 384, 68 + otherwise = Scope(expr3), // 3 + ) + val (res, newSc) = Simulate.sim(expr, sc) + val exp = 3 + assert(res == exp, s"Expected $exp, got $res") + + // There should be 2 reads in the simulation context + assert(newSc.reads.contains(ReadBuf(buffer, 32))) + assert(newSc.reads.contains(ReadBuf(buffer, 64))) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala new file mode 100644 index 00000000..68fcb8af --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -0,0 +1,59 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* +import Result.Result +import izumi.reflect.Tag + +case class InvocResult( + invocId: Int, + instructions: List[Expression[?]] = Nil, + values: List[Result] = Nil, + writes: List[Writes] = Nil, + reads: List[Reads] = Nil, +) + +case class InterpretResult(invocs: List[InvocResult] = Nil) + +object Interpreter: + private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match + case Pure(value) => + val (result, newSc) = Simulate.sim(value.asInstanceOf[Value], sc) // TODO needs fixing + newSc.addResult(result) + + private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match + case WriteBuffer(buffer, index, value) => + val (n, _) = Simulate.sim(index, SimContext()) // Int32, no reads/writes here, don't need resulting context + val i = n.asInstanceOf[Int] + val (res, newSc) = Simulate.sim(value, sc) + newSc.addWrite(WriteBuf(buffer, i, res)) + + private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match + case WriteUniform(uniform, value) => + val (result, newSc) = Simulate.sim(value.asInstanceOf[Value], sc) // TODO needs fixing + newSc.addWrite(WriteUni(uniform, result)) + + private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match + case p: Pure[?] => interpretPure(p, sc) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, sc) + case wu: WriteUniform[?] => interpretWriteUniform(wu, sc) + case _ => throw IllegalArgumentException("interpretOne: invalid GIO") + + @annotation.tailrec + private def interpretMany(gios: List[GIO[?]], sc: SimContext): SimContext = gios match + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, sc) + case Repeat(n, f) :: tail => + val (i, _) = Simulate.sim(n, SimContext()) // just Int32, no reads/writes + val int = i.asInstanceOf[Int] + val newGios = (0 until int).map(i => f(i)).toList + interpretMany(newGios ::: tail, sc) + case head :: tail => + val newSc = interpretOne(head, sc) + interpretMany(tail, newSc) + case Nil => sc + + def interpret(gio: GIO[?], invocId: Int): InvocResult = + val sc = interpretMany(List(gio), SimContext()) + InvocResult(invocId, ???, sc.values, sc.writes, sc.reads) + + def interpret(gio: GIO[?], invocIds: List[Int]): InterpretResult = InterpretResult(invocIds.map(interpret(gio, _))) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala new file mode 100644 index 00000000..ad3d7634 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala @@ -0,0 +1,94 @@ +package io.computenode.cyfra.interpreter + +object Result: + export ScalarResult.*, VectorResult.* + + type Result = ScalarRes | Vector[ScalarRes] + + extension (r: Result) + def negate: Result = r match + case s: ScalarRes => s.neg + case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd + + def bitNeg: Int = r match + case sr: ScalarRes => ~sr + case _ => throw IllegalArgumentException("bitNeg: wrong argument type") + + def shiftLeft(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n << b + case _ => throw IllegalArgumentException("shiftLeft: incompatible argument types") + + def shiftRight(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n >> b + case _ => throw IllegalArgumentException("shiftRight: incompatible argument types") + + def bitAnd(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s & t + case _ => throw IllegalArgumentException("bitAnd: incompatible argument types") + + def bitOr(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s | t + case _ => throw IllegalArgumentException("bitOr: incompatible argument types") + + def bitXor(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s ^ t + case _ => throw IllegalArgumentException("bitXor: incompatible argument types") + + def add(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s + t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t + case _ => throw IllegalArgumentException("add: incompatible argument types") + + def sub(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s - t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t + case _ => throw IllegalArgumentException("sub: incompatible argument types") + + def mul(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s * t + case _ => throw IllegalArgumentException("mul: incompatible argument types") + + def div(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s / t + case _ => throw IllegalArgumentException("div: incompatible argument types") + + def mod(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s % t + case _ => throw IllegalArgumentException("mod: incompatible argument types") + + def scale(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: ScalarRes) => v scale t + case _ => throw IllegalArgumentException("scale: incompatible argument types") + + def dot(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v dot t + case _ => throw IllegalArgumentException("dot: incompatible argument types") + + def &&(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + def ||(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + def gt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr > t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def lt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr < t + case _ => throw IllegalArgumentException("lt: incompatible argument types") + + def gteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr >= t + case _ => throw IllegalArgumentException("gteq: incompatible argument types") + + def lteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr <= t + case _ => throw IllegalArgumentException("lteq: incompatible argument types") + + def eql(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr === t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t + case _ => throw IllegalArgumentException("eql: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala new file mode 100644 index 00000000..8d3e537f --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala @@ -0,0 +1,92 @@ +package io.computenode.cyfra.interpreter + +object ScalarResult: + type ScalarRes = Float | Int | Boolean + + extension (sr: ScalarRes) + def neg: ScalarRes = sr match + case f: Float => -f + case n: Int => -n + case b: Boolean => !b + + infix def unary_~ : Int = sr match + case n: Int => ~n + case _ => throw IllegalArgumentException("~: wrong argument type") + + infix def <<(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n << b + case _ => throw IllegalArgumentException("<<: incompatible argument types") + + infix def >>(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n >> b + case _ => throw IllegalArgumentException(">>: incompatible argument types") + + infix def &(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m & n + case _ => throw IllegalArgumentException("&: incompatible argument types") + + infix def |(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m | n + case _ => throw IllegalArgumentException("|: incompatible argument types") + + infix def ^(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m ^ n + case _ => throw IllegalArgumentException("^: incompatible argument types") + + infix def +(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f + t + case (n: Int, t: Int) => n + t + case _ => throw IllegalArgumentException("+: incompatible argument types") + + infix def -(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f - t + case (n: Int, t: Int) => n - t + case _ => throw IllegalArgumentException("-: incompatible argument types") + + infix def *(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f * t + case (n: Int, t: Int) => n * t + case _ => throw IllegalArgumentException("*: incompatible argument types") + + infix def /(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f / t + case (n: Int, t: Int) => n / t + case _ => throw IllegalArgumentException("/: incompatible argument types") + + infix def %(that: ScalarRes): Int = (sr, that) match + case (n: Int, t: Int) => n % t + case _ => throw IllegalArgumentException("%: incompatible argument types") + + infix def &&(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + infix def ||(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + infix def >(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f > t + case (n: Int, t: Int) => n > t + case _ => throw IllegalArgumentException(">: incompatible argument types") + + infix def <(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f < t + case (n: Int, t: Int) => n < t + case _ => throw IllegalArgumentException("<: incompatible argument types") + + infix def >=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f >= t + case (n: Int, t: Int) => n >= t + case _ => throw IllegalArgumentException(">=: incompatible argument types") + + infix def <=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f <= t + case (n: Int, t: Int) => n <= t + case _ => throw IllegalArgumentException("<=: incompatible argument types") + + infix def ===(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => Math.abs(f - t) < 0.001f + case (n: Int, t: Int) => n == t + case (b: Boolean, t: Boolean) => b == t + case _ => throw IllegalArgumentException("===: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala new file mode 100644 index 00000000..33917f94 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -0,0 +1,41 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} +import Result.Result + +enum Reads: + case ReadBuf(buffer: GBuffer[?], index: Int) + case ReadUni(uniform: GUniform[?]) +export Reads.* + +enum Writes: + case WriteBuf(buffer: GBuffer[?], index: Int, value: Result) + case WriteUni(uni: GUniform[?], value: Result) +export Writes.* + +case class SimContext( + bufMap: Map[GBuffer[?], Array[Result]] = Map(), + uniMap: Map[GUniform[?], Result] = Map(), + values: List[Result] = Nil, + writes: List[Writes] = Nil, + reads: List[Reads] = Nil, +): + def addBuffer(buffer: GBuffer[?], array: Array[Result]): SimContext = copy(bufMap = bufMap + (buffer -> array)) + + def addRead(read: Reads): SimContext = read match + case ReadBuf(buffer, index) => copy(reads = ReadBuf(buffer, index) :: reads) + case ReadUni(uniform) => copy(reads = ReadUni(uniform) :: reads) + + def addWrite(write: Writes): SimContext = write match + case WriteBuf(buffer, index, value) => + val newArray = bufMap(buffer).updated(index, value) + val newWrites = WriteBuf(buffer, index, value) :: writes + copy(bufMap = bufMap.updated(buffer, newArray), writes = newWrites) + case WriteUni(uni, value) => + val newWrites = WriteUni(uni, value) :: writes + copy(uniMap = uniMap.updated(uni, value), writes = newWrites) + + def addResult(res: Result) = copy(values = res :: values) + def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) + def lookupUni(uniform: GUniform[?]): Result = uniMap(uniform) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala new file mode 100644 index 00000000..2f691f6c --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -0,0 +1,174 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, macros.FnCall.FnIdentifier, control.Scope +import collections.*, GArray.GArrayElem, GSeq.{CurrentElem, AggregateElem, FoldSeq} +import struct.*, GStruct.{ComposeStruct, GetField} +import io.computenode.cyfra.spirv.BlockBuilder.buildBlock +import collection.mutable.Map as MMap + +object Simulate: + import Result.* + + def sim(v: Value, sc: SimContext): (Result, SimContext) = sim(v.tree, sc) + def sim(e: Expression[?], sc: SimContext = SimContext()): (Result, SimContext) = simIterate(buildBlock(e), sc)(using Map()) + + @annotation.tailrec + def simIterate(blocks: List[Expression[?]], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = blocks match + case head :: Nil => simOne(head, sc) + case head :: next => + val (result, newSc) = simOne(head, sc) // context updated if there are reads/writes + val newExprMap = exprMap + (head.treeid -> result) // update map with new result + simIterate(next, newSc)(using newExprMap) + case Nil => ??? // should not happen + + def simOne(e: Expression[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = e match + case e: PhantomExpression[?] => (simPhantom(e), sc) + case Negate(a) => (simValue(a).negate, sc) + case e: BinaryOpExpression[?] => (simBinOp(e), sc) + case ScalarProd(a, b) => (simVector(a).scale(simScalar(b)), sc) + case DotProd(a, b) => (simVector(a).dot(simVector(b)), sc) + case e: BitwiseOpExpression[?] => (simBitwiseOp(e), sc) + case e: ComparisonOpExpression[?] => (simCompareOp(e), sc) + case And(a, b) => (simScalar(a) && simScalar(b), sc) + case Or(a, b) => (simScalar(a) || simScalar(b), sc) + case Not(a) => (simScalar(a).negate, sc) + case ExtractScalar(a, i) => (simVector(a).apply(simValue(i).asInstanceOf[Int]), sc) + case e: ConvertExpression[?, ?] => (simConvert(e), sc) + case e: Const[?] => (simConst(e), sc) + case ComposeVec2(a, b) => (Vector(simScalar(a), simScalar(b)), sc) + case ComposeVec3(a, b, c) => (Vector(simScalar(a), simScalar(b), simScalar(c)), sc) + case ComposeVec4(a, b, c, d) => (Vector(simScalar(a), simScalar(b), simScalar(c), simScalar(d)), sc) + case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue), sc) + case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue), sc) + case InvocationId => ??? + case Pass(value) => ??? + case Dynamic(source) => ??? + case e: WhenExpr[?] => simWhen(e, sc) // returns new SimContext + case e: ReadBuffer[?] => simReadBuffer(e, sc) // returns new SimContext + case e: ReadUniform[?] => simReadUniform(e, sc) // returns new SimContext + case e: GArrayElem[?] => simGArrayElem(e) + case e: FoldSeq[?, ?] => simFoldSeq(e) + case e: ComposeStruct[?] => simComposeStruct(e) + case e: GetField[?, ?] => simGetField(e) + case _ => throw IllegalArgumentException("sim: wrong argument") + + private def simPhantom(e: PhantomExpression[?]): Result = e match + case CurrentElem(tid: Int) => ??? + case AggregateElem(tid: Int) => ??? + + private def simBinOp(e: BinaryOpExpression[?])(using exprMap: Map[Int, Result]): Result = e match + case Sum(a, b) => simValue(a).add(simValue(b)) // scalar or vector + case Diff(a, b) => simValue(a).sub(simValue(b)) // scalar or vector + case Mul(a, b) => simScalar(a).mul(simScalar(b)) + case Div(a, b) => simScalar(a).div(simScalar(b)) + case Mod(a, b) => simScalar(a).mod(simScalar(b)) + + private def simBitwiseOp(e: BitwiseOpExpression[?])(using exprMap: Map[Int, Result]): Int = e match + case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) + case BitwiseNot(a) => simScalar(a).bitNeg + case ShiftLeft(a, by) => simScalar(a).shiftLeft(simScalar(by)) + case ShiftRight(a, by) => simScalar(a).shiftRight(simScalar(by)) + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using exprMap: Map[Int, Result]): Int = e match + case BitwiseAnd(a, b) => simScalar(a).bitAnd(simScalar(b)) + case BitwiseOr(a, b) => simScalar(a).bitOr(simScalar(b)) + case BitwiseXor(a, b) => simScalar(a).bitXor(simScalar(b)) + + private def simCompareOp(e: ComparisonOpExpression[?])(using exprMap: Map[Int, Result]): Boolean = e match + case GreaterThan(a, b) => simScalar(a).gt(simScalar(b)) + case LessThan(a, b) => simScalar(a).lt(simScalar(b)) + case GreaterThanEqual(a, b) => simScalar(a).gteq(simScalar(b)) + case LessThanEqual(a, b) => simScalar(a).lteq(simScalar(b)) + case Equal(a, b) => simScalar(a).eql(simScalar(b)) + + private def simConvert(e: ConvertExpression[?, ?])(using exprMap: Map[Int, Result]): Float | Int = e match + case ToFloat32(a) => + exprMap(a.treeid) match + case f: Float => f + case _ => throw IllegalArgumentException("ToFloat32: wrong argument type") + case ToInt32(a) => + exprMap(a.treeid) match + case n: Int => n + case _ => throw IllegalArgumentException("ToInt32: wrong argument type") + case ToUInt32(a) => + exprMap(a.treeid) match + case n: Int => n + case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") + + private def simConst(e: Const[?]): ScalarRes = e match + case ConstFloat32(value) => value + case ConstInt32(value) => value + case ConstUInt32(value) => value + case ConstGB(value) => value + + private def simValue(v: Value)(using exprMap: Map[Int, Result]): Result = v match + case v: Scalar => simScalar(v) + case v: Vec[?] => simVector(v) + + private def simScalar(v: Scalar)(using exprMap: Map[Int, Result]): ScalarRes = v match + case v: FloatType => exprMap(v.tree.treeid).asInstanceOf[Float] + case v: IntType => exprMap(v.tree.treeid).asInstanceOf[Int] + case v: UIntType => exprMap(v.tree.treeid).asInstanceOf[Int] + case GBoolean(source) => exprMap(source.treeid).asInstanceOf[Boolean] + + private def simVector(v: Vec[?])(using exprMap: Map[Int, Result]): Vector[ScalarRes] = v match + case Vec2(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec3(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec4(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + + private def simExtFunc(fn: FunctionName, args: List[Result], sc: SimContext): (Result, SimContext) = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result], sc: SimContext): (Result, SimContext) = ??? + + @annotation.tailrec + private def whenHelper( + when: Expression[GBoolean], + thenCode: Scope[?], + otherConds: List[Scope[GBoolean]], + otherCaseCodes: List[Scope[?]], + otherwise: Scope[?], + sc: SimContext, + ): (Result, SimContext) = + // scopes are not included in the "main" exprMap, they have to be simulated from scratch. + // there could be reads/writes happening in scopes, SimContext has to be updated. + val (boolRes, newSc) = sim(when, sc) + if boolRes.asInstanceOf[Boolean] then sim(thenCode.expr, newSc) + else + otherConds.headOption match + case None => sim(otherwise.expr, newSc) + case Some(cond) => + whenHelper( + when = cond.expr, + thenCode = otherCaseCodes.head, + otherConds = otherConds.tail, + otherCaseCodes = otherCaseCodes.tail, + otherwise = otherwise, + sc = newSc, + ) + + private def simWhen(e: WhenExpr[?], sc: SimContext): (Result, SimContext) = e match + case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => + whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, sc) + + private def simReadBuffer(e: ReadBuffer[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = e match + case ReadBuffer(buffer, index) => + val i = exprMap(index.tree.treeid).asInstanceOf[Int] + val newSc = sc.addRead(ReadBuf(buffer, i)) + (newSc.lookup(buffer, i), newSc) + + private def simReadUniform(uni: ReadUniform[?], sc: SimContext): (Result, SimContext) = uni match + case ReadUniform(uniform) => + val newSc = sc.addRead(ReadUni(uniform)) + (newSc.lookupUni(uniform), newSc) + + private def simGArrayElem(gElem: GArrayElem[?]): (Result, SimContext) = gElem match + case GArrayElem(index, i) => ??? + + private def simFoldSeq(seq: FoldSeq[?, ?]): (Result, SimContext) = seq match + case FoldSeq(zero, fn, seq) => ??? + + private def simComposeStruct(cs: ComposeStruct[?]): (Result, SimContext) = cs match + case ComposeStruct(fields, resultSchema) => ??? + + private def simGetField(gf: GetField[?, ?]): (Result, SimContext) = gf match + case GetField(struct, fieldIndex) => ??? diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala new file mode 100644 index 00000000..dde9ce36 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala @@ -0,0 +1,23 @@ +package io.computenode.cyfra.interpreter + +object VectorResult: + import ScalarResult.* + + extension (v: Vector[ScalarRes]) + infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _) + infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _) + infix def eql(that: Vector[ScalarRes]): Boolean = v.zip(that).forall(_ === _) + infix def scale(s: ScalarRes) = v.map(_ * s) + + def sumRes: Float | Int = v.headOption match + case None => 0 + case Some(value) => + value match + case f: Float => v.asInstanceOf[Vector[Float]].sum + case n: Int => v.asInstanceOf[Vector[Int]].sum + case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans") + + infix def dot(that: Vector[ScalarRes]): Float | Int = v + .zip(that) + .map(_ * _) + .sumRes