From fb8147c1b0bef9637debb96ee75feb69a4707391 Mon Sep 17 00:00:00 2001 From: spamegg Date: Wed, 9 Jul 2025 17:07:36 +0300 Subject: [PATCH 01/13] interpreter --- .github/workflows/ci.yml | 1 + build.sbt | 8 +- .../cyfra/e2e/interpreter/SimulateTests.scala | 51 ++++++ .../cyfra/interpreter/Interpreter.scala | 51 ++++++ .../cyfra/interpreter/Result.scala | 102 +++++++++++ .../cyfra/interpreter/Simulate.scala | 171 ++++++++++++++++++ 6 files changed, 382 insertions(+), 2 deletions(-) create mode 100644 cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala create mode 100644 cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala create mode 100644 cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala create mode 100644 cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c47d94a8..64d5109b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,6 +3,7 @@ on: push: branches: - main + - dev tags: - "v*" pull_request: diff --git a/build.sbt b/build.sbt index 44acec5f..84e95b4e 100644 --- a/build.sbt +++ b/build.sbt @@ -96,13 +96,17 @@ lazy val vscode = (project in file("cyfra-vscode")) .settings(commonSettings) .dependsOn(foton) +lazy val interpreter = (project in file("cyfra-interpreter")) + .settings(commonSettings) + .dependsOn(core) + lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(commonSettings, runnerSettings) - .dependsOn(runtime) + .dependsOn(runtime, interpreter) lazy val root = (project in file(".")) .settings(name := "Cyfra") - .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples) + .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, interpreter) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala new file mode 100644 index 00000000..b1916bbe --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -0,0 +1,51 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import Value.FromExpr.fromExpr, control.Scope + +class SimulateE2eTest extends munit.FunSuite: + test("simulate binary operation arithmetic"): + val a: Int32 = 1 + val b: Int32 = 2 + val c: Int32 = 3 + val d: Int32 = 4 + val e: Int32 = 5 + val f: Int32 = 6 + val e1 = Diff(a, b) + val e2 = Sum(fromExpr(e1), c) + val e3 = Mul(f, fromExpr(e2)) + val e4 = Div(fromExpr(e3), d) + val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4) + val result = Simulate.sim(expr) + val expected = 2 + assert(result == expected, s"Expected $expected, got $result") + + test("simulate vec4, scalar, dot"): + val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) + val res1 = Simulate.sim(v1) + val exp1 = Vector(1f, 2f, 3f, 4f) + assert(res1 == exp1, s"Expected $exp1, got $res1") + + val v2 = ScalarProd(fromExpr(v1), -1f) + val res2 = Simulate.sim(v2) + val exp2 = Vector(-1f, -2f, -3f, -4f) + assert(res2 == exp2, s"Expected $exp2, got $res2") + + val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f) + val dot = DotProd(fromExpr(v1), fromExpr(v3)) + val res3 = Simulate.sim(dot).asInstanceOf[Float] + val exp3 = 0f + assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3") + + test("when test"): + val expr = WhenExpr( + when = 2 <= 1, + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 >= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val res = Simulate.sim(expr) + val exp = 3 + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala new file mode 100644 index 00000000..8e153a1b --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -0,0 +1,51 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given}, collections.GSeq +import binding.*, Value.*, gio.GIO, GIO.* +import izumi.reflect.Tag + +object Interpreter: + case class Write(buffer: GBuffer[?], index: Int, value: Any) + case class Read(buffer: GBuffer[?], index: Int) + case class InvocSimResult( + invocId: Int, + instructions: List[Expression[?]], + values: List[Any] = Nil, + writes: List[Write] = Nil, + reads: List[Read] = Nil, + ) + case class SimResult(invocs: List[InvocSimResult] = Nil): + def add(that: SimResult) = SimResult(that.invocs ::: invocs) + + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + + // val bufferA = SimGBuffer[Int32]() + // val gbuffers = Map[GBuffer[?], Array[Int32]](bufferA -> Array.fill(1024)(0)) + // val expression = bufferA.read(0) + 2 + // val res = Simulate.sim(expression, 1024) -> SimulationResult(???) + + def interpretPure(gio: Pure[?]): SimResult = gio match + case Pure(value) => + val id = Simulate.sim(invocationId).asInstanceOf[Int] + val invocSimRes = InvocSimResult(invocId = id, instructions = ???, values = ???) + SimResult(List(invocSimRes)) + + def interpretOne(gio: GIO[?]): SimResult = gio match + case Pure(value) => ??? + case WriteBuffer(buffer, index, value) => ??? + case WriteUniform(uniform, value) => ??? + case _ => throw IllegalArgumentException("interpret: invalid GIO") + + @annotation.tailrec + def interpretMany(gios: List[GIO[?]], simRes: SimResult): SimResult = gios match + case head :: tail => + head match + case FlatMap(gio, next) => interpretMany(gio :: next :: tail, simRes) + case Repeat(n, f) => + val int = Simulate.sim(n).asInstanceOf[Int] + val newGios = (0 until int).map(i => f(i)).toList + interpretMany(newGios ::: tail, simRes) + case _ => interpretMany(tail, simRes.add(interpretOne(head))) + case Nil => simRes + + def interpret(gio: GIO[?]): SimResult = interpretMany(List(gio), SimResult()) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala new file mode 100644 index 00000000..d5c79862 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala @@ -0,0 +1,102 @@ +package io.computenode.cyfra.interpreter + +object Result: + type ScalarRes = Float | Int | Boolean + type Result = ScalarRes | Vector[ScalarRes] + + extension (sr: ScalarRes) + def neg: ScalarRes = sr match + case f: Float => -f + case n: Int => -n + case b: Boolean => !b + + infix def +(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f + t + case (n: Int, t: Int) => n + t + case _ => throw IllegalArgumentException("+: incompatible argument types") + + infix def -(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f - t + case (n: Int, t: Int) => n - t + case _ => throw IllegalArgumentException("-: incompatible argument types") + + infix def *(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f * t + case (n: Int, t: Int) => n * t + case _ => throw IllegalArgumentException("*: incompatible argument types") + + infix def /(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f / t + case (n: Int, t: Int) => n / t + case _ => throw IllegalArgumentException("/: incompatible argument types") + + infix def %(that: ScalarRes) = (sr, that) match + case (n: Int, t: Int) => n % t + case _ => throw IllegalArgumentException("%: incompatible argument types") + + infix def &&(that: ScalarRes) = (sr, that) match + case (b: Boolean, t: Boolean) => b && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + infix def ||(that: ScalarRes) = (sr, that) match + case (b: Boolean, t: Boolean) => b || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + infix def >(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f > t + case (n: Int, t: Int) => n > t + case _ => throw IllegalArgumentException(">: incompatible argument types") + + infix def <(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f < t + case (n: Int, t: Int) => n < t + case _ => throw IllegalArgumentException("<: incompatible argument types") + + infix def >=(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f >= t + case (n: Int, t: Int) => n >= t + case _ => throw IllegalArgumentException(">=: incompatible argument types") + + infix def <=(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => f <= t + case (n: Int, t: Int) => n <= t + case _ => throw IllegalArgumentException("<=: incompatible argument types") + + infix def eql(that: ScalarRes) = (sr, that) match + case (f: Float, t: Float) => Math.abs(f - t) < 0.001f + case (n: Int, t: Int) => n == t + case (b: Boolean, t: Boolean) => b == t + case _ => throw IllegalArgumentException("<=: incompatible argument types") + + extension (v: Vector[ScalarRes]) + infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _) + infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _) + infix def scale(s: ScalarRes) = v.map(_ * s) + + def sumRes: ScalarRes = v.headOption match + case None => 0 + case Some(value) => + value match + case f: Float => v.asInstanceOf[Vector[Float]].sum + case n: Int => v.asInstanceOf[Vector[Int]].sum + case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans") + + infix def dot(that: Vector[ScalarRes]) = v + .zip(that) + .map(_ * _) + .sumRes + + extension (r: Result) + def negate: Result = r match + case s: ScalarRes => s.neg + case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd + + infix def add(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s + t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t + case _ => throw IllegalArgumentException("add: incompatible argument types") + + infix def sub(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s - t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t + case _ => throw IllegalArgumentException("sub: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala new file mode 100644 index 00000000..f82c80c0 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -0,0 +1,171 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, macros.FnCall.FnIdentifier, control.Scope +import collections.*, GArray.GArrayElem, GSeq.{CurrentElem, AggregateElem, FoldSeq} +import struct.*, GStruct.{ComposeStruct, GetField} + +object Simulate: + import Result.* + + def sim(v: Value): Result = sim(v.tree) // helpful wrapper for Value instead of Expression + + def sim(e: Expression[?]): Result = e match + case e: PhantomExpression[?] => simPhantom(e) + case Negate(a) => simValue(a).negate + case e: BinaryOpExpression[?] => simBinOp(e) + case ScalarProd(a, b) => simVector(a) scale simScalar(b) + case DotProd(a, b) => simVector(a) dot simVector(b) + case e: BitwiseOpExpression[?] => simBitwiseOp(e) + case e: ComparisonOpExpression[?] => simCompareOp(e) + case And(a, b) => simScalar(a) && simScalar(b) + case Or(a, b) => simScalar(a) || simScalar(b) + case Not(a) => simScalar(a).neg + case ExtractScalar(a, i) => ??? + case e: ConvertExpression[?, ?] => simConvert(e) + case e: Const[?] => simConst(e) + case ComposeVec2(a, b) => Vector(simScalar(a), simScalar(b)) + case ComposeVec3(a, b, c) => Vector(simScalar(a), simScalar(b), simScalar(c)) + case ComposeVec4(a, b, c, d) => Vector(simScalar(a), simScalar(b), simScalar(c), simScalar(d)) + case ExtFunctionCall(fn, args) => simExtFunc(fn, args.map(simValue)) + case FunctionCall(fn, body, args) => simFunc(fn, simScope(body), args.map(simValue)) + case InvocationId => ??? + case Pass(value) => ??? + case Dynamic(source) => ??? + case e: WhenExpr[?] => simWhen(e) + case e: ReadBuffer[?] => simReadBuffer(e) + case e: ReadUniform[?] => simReadUniform(e) + case e: GArrayElem[?] => simGArrayElem(e) + case e: FoldSeq[?, ?] => simFoldSeq(e) + case e: ComposeStruct[?] => ??? + case e: GetField[?, ?] => ??? + case _ => throw IllegalArgumentException("wrong argument") + + private def simPhantom(e: PhantomExpression[?]): Result = e match + case CurrentElem(tid: Int) => ??? + case AggregateElem(tid: Int) => ??? + + private def simBinOp(e: BinaryOpExpression[?]): Result = e match + case Sum(a, b) => simValue(a) add simValue(b) // scalar or vector + case Diff(a, b) => simValue(a) sub simValue(b) // scalar or vector + case Mul(a, b) => simScalar(a) * simScalar(b) + case Div(a, b) => simScalar(a) / simScalar(b) + case Mod(a, b) => simScalar(a) % simScalar(b) + + private def simBitwiseOp(e: BitwiseOpExpression[?]): Int = e match + case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) + case BitwiseNot(a) => + simScalar(a) match + case m: Int => ~m + case _ => throw IllegalArgumentException("BitwiseNot: wrong argument type") + case ShiftLeft(a, by) => + (simScalar(a), simScalar(by)) match + case (m: Int, n: Int) => m << n + case _ => throw IllegalArgumentException("ShiftLeft: wrong argument types") + case ShiftRight(a, by) => + (simScalar(a), simScalar(by)) match + case (m: Int, n: Int) => m >> n + case _ => throw IllegalArgumentException("ShiftRight: wrong argument types") + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?]) = e match + case BitwiseAnd(a, b) => + (simScalar(a), simScalar(b)) match + case (m: Int, n: Int) => m & n + case _ => throw IllegalArgumentException("BitwiseAnd: wrong argument types") + case BitwiseOr(a, b) => + (simScalar(a), simScalar(b)) match + case (m: Int, n: Int) => m | n + case _ => throw IllegalArgumentException("BitwiseOr: wrong argument types") + case BitwiseXor(a, b) => + (simScalar(a), simScalar(b)) match + case (m: Int, n: Int) => m ^ n + case _ => throw IllegalArgumentException("BitwiseXor: wrong argument types") + + private def simCompareOp(e: ComparisonOpExpression[?]): Boolean = e match + case GreaterThan(a, b) => simScalar(a) > simScalar(b) + case LessThan(a, b) => simScalar(a) < simScalar(b) + case GreaterThanEqual(a, b) => simScalar(a) >= simScalar(b) + case LessThanEqual(a, b) => simScalar(a) <= simScalar(b) + case Equal(a, b) => simScalar(a) eql simScalar(b) + + private def simConvert(e: ConvertExpression[?, ?]): Float | Int = e match + case ToFloat32(a) => + simScalar(a) match + case f: Float => f + case _ => throw IllegalArgumentException("ToFloat32: wrong argument type") + case ToInt32(a) => + simScalar(a) match + case n: Int => n + case _ => throw IllegalArgumentException("ToInt32: wrong argument type") + case ToUInt32(a) => + simScalar(a) match + case n: Int => n + case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") + + private def simConst(e: Const[?]): ScalarRes = e match + case ConstFloat32(value) => value + case ConstInt32(value) => value + case ConstUInt32(value) => value + case ConstGB(value) => value + + private def simValue(v: Value): Result = v match + case v: Scalar => simScalar(v) + case v: Vec[?] => simVector(v) + + private def simScalar(v: Scalar): ScalarRes = v match + case v: FloatType => sim(v.tree).asInstanceOf[Float] + case v: IntType => sim(v.tree).asInstanceOf[Int] + case v: UIntType => sim(v.tree).asInstanceOf[Int] + case GBoolean(source) => sim(source).asInstanceOf[Boolean] + + private def simVector(v: Vec[?]): Vector[ScalarRes] = v match + case Vec2(tree) => sim(tree).asInstanceOf[Vector[ScalarRes]] + case Vec3(tree) => sim(tree).asInstanceOf[Vector[ScalarRes]] + case Vec4(tree) => sim(tree).asInstanceOf[Vector[ScalarRes]] + + private def simExtFunc(fn: FunctionName, args: List[Result]): Result = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result]): Result = ??? + private def simScope(body: Scope[?]) = sim(body.expr) + + @annotation.tailrec + private def whenHelper( + when: GBoolean, + thenCode: Scope[?], + otherConds: List[Scope[GBoolean]], + otherCaseCodes: List[Scope[?]], + otherwise: Scope[?], + ): Result = + if sim(when.tree).asInstanceOf[Boolean] then sim(thenCode.expr) + else + otherConds.headOption match + case None => sim(otherwise.expr) + case Some(cond) => + whenHelper( + when = GBoolean(cond.expr), + thenCode = otherCaseCodes.head, + otherConds = otherConds.tail, + otherCaseCodes = otherCaseCodes.tail, + otherwise = otherwise, + ) + + private def simWhen(e: WhenExpr[?]): Result = e match + case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => + whenHelper(when, thenCode, otherConds, otherCaseCodes, otherwise) + + private def simReadBuffer(buf: ReadBuffer[?]): Result = buf match + case ReadBuffer(buffer, index) => ??? + + private def simReadUniform(uni: ReadUniform[?]): Result = uni match + case ReadUniform(uniform) => ??? + + private def simGArrayElem(gElem: GArrayElem[?]): Result = gElem match + case GArrayElem(index, i) => ??? + + private def simFoldSeq(seq: FoldSeq[?, ?]): Result = seq match + case FoldSeq(zero, fn, seq) => ??? + + private def simComposeStruct(cs: ComposeStruct[?]): Result = cs match + case ComposeStruct(fields, resultSchema) => ??? + + private def simGetField(gf: GetField[?, ?]): Result = gf match + case GetField(struct, fieldIndex) => ??? From e0b9b80e6a6c3487d73391da58222c530951cab2 Mon Sep 17 00:00:00 2001 From: spamegg Date: Fri, 11 Jul 2025 17:35:55 +0300 Subject: [PATCH 02/13] refactor Simulator to avoid recursion --- .../e2e/interpreter/InterpreterTests.scala | 11 ++ .../cyfra/e2e/interpreter/SimulateTests.scala | 25 ++- .../cyfra/interpreter/Interpreter.scala | 12 +- .../cyfra/interpreter/Result.scala | 162 ++++++++--------- .../cyfra/interpreter/ScalarResult.scala | 92 ++++++++++ .../cyfra/interpreter/Simulate.scala | 164 +++++++++--------- .../cyfra/interpreter/VectorResult.scala | 23 +++ 7 files changed, 317 insertions(+), 172 deletions(-) create mode 100644 cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala create mode 100644 cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala create mode 100644 cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala new file mode 100644 index 00000000..7b5f57a8 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala @@ -0,0 +1,11 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import Value.FromExpr.fromExpr, control.Scope + +class InterpreterE2eTest extends munit.FunSuite: + test("stub"): + val res = 0 + val exp = 0 + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index b1916bbe..5f3a414e 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -38,7 +38,8 @@ class SimulateE2eTest extends munit.FunSuite: val exp3 = 0f assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3") - test("when test"): + // currently not working due to Scope + test("simulate when elseWhen otherwise".ignore): val expr = WhenExpr( when = 2 <= 1, thenCode = Scope(ConstInt32(1)), @@ -49,3 +50,25 @@ class SimulateE2eTest extends munit.FunSuite: val res = Simulate.sim(expr) val exp = 3 assert(res == exp, s"Expected $exp, got $res") + + test("simulate bitwise ops"): + val a: Int32 = 5 + val by: UInt32 = 3 + val aNot = BitwiseNot(a) + val left = ShiftLeft(fromExpr(aNot), by) + val right = ShiftRight(fromExpr(aNot), by) + val and = BitwiseAnd(fromExpr(left), fromExpr(right)) + val or = BitwiseOr(fromExpr(left), fromExpr(right)) + val xor = BitwiseXor(fromExpr(and), fromExpr(or)) + + val res = Simulate.sim(xor) + val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) + assert(res == exp, s"Expected $exp, got $res") + + test("simulate stack overflow"): + val a: Int32 = 1 + var sum = Sum(a, a) // 2 + for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) + val res = Simulate.sim(sum) + val exp = 1000002 + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 8e153a1b..c9473c3e 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.interpreter -import io.computenode.cyfra.dsl.{*, given}, collections.GSeq +import io.computenode.cyfra.dsl.{*, given} import binding.*, Value.*, gio.GIO, GIO.* import izumi.reflect.Tag @@ -9,7 +9,7 @@ object Interpreter: case class Read(buffer: GBuffer[?], index: Int) case class InvocSimResult( invocId: Int, - instructions: List[Expression[?]], + instructions: List[Expression[?]] = Nil, values: List[Any] = Nil, writes: List[Write] = Nil, reads: List[Read] = Nil, @@ -27,7 +27,13 @@ object Interpreter: def interpretPure(gio: Pure[?]): SimResult = gio match case Pure(value) => val id = Simulate.sim(invocationId).asInstanceOf[Int] - val invocSimRes = InvocSimResult(invocId = id, instructions = ???, values = ???) + val invocSimRes = InvocSimResult(invocId = id, values = List(value)) + SimResult(List(invocSimRes)) + + def interpretWriteBuffer(gio: Pure[?]): SimResult = gio match + case Pure(value) => + val id = Simulate.sim(invocationId).asInstanceOf[Int] + val invocSimRes = InvocSimResult(invocId = id, values = List(value)) SimResult(List(invocSimRes)) def interpretOne(gio: GIO[?]): SimResult = gio match diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala index d5c79862..23307374 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala @@ -1,102 +1,94 @@ package io.computenode.cyfra.interpreter object Result: - type ScalarRes = Float | Int | Boolean - type Result = ScalarRes | Vector[ScalarRes] + export ScalarResult.*, VectorResult.* - extension (sr: ScalarRes) - def neg: ScalarRes = sr match - case f: Float => -f - case n: Int => -n - case b: Boolean => !b - - infix def +(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f + t - case (n: Int, t: Int) => n + t - case _ => throw IllegalArgumentException("+: incompatible argument types") - - infix def -(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f - t - case (n: Int, t: Int) => n - t - case _ => throw IllegalArgumentException("-: incompatible argument types") - - infix def *(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f * t - case (n: Int, t: Int) => n * t - case _ => throw IllegalArgumentException("*: incompatible argument types") - - infix def /(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f / t - case (n: Int, t: Int) => n / t - case _ => throw IllegalArgumentException("/: incompatible argument types") - - infix def %(that: ScalarRes) = (sr, that) match - case (n: Int, t: Int) => n % t - case _ => throw IllegalArgumentException("%: incompatible argument types") - - infix def &&(that: ScalarRes) = (sr, that) match - case (b: Boolean, t: Boolean) => b && t - case _ => throw IllegalArgumentException("&&: incompatible argument types") - - infix def ||(that: ScalarRes) = (sr, that) match - case (b: Boolean, t: Boolean) => b || t - case _ => throw IllegalArgumentException("||: incompatible argument types") - - infix def >(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f > t - case (n: Int, t: Int) => n > t - case _ => throw IllegalArgumentException(">: incompatible argument types") - - infix def <(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f < t - case (n: Int, t: Int) => n < t - case _ => throw IllegalArgumentException("<: incompatible argument types") - - infix def >=(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f >= t - case (n: Int, t: Int) => n >= t - case _ => throw IllegalArgumentException(">=: incompatible argument types") - - infix def <=(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => f <= t - case (n: Int, t: Int) => n <= t - case _ => throw IllegalArgumentException("<=: incompatible argument types") - - infix def eql(that: ScalarRes) = (sr, that) match - case (f: Float, t: Float) => Math.abs(f - t) < 0.001f - case (n: Int, t: Int) => n == t - case (b: Boolean, t: Boolean) => b == t - case _ => throw IllegalArgumentException("<=: incompatible argument types") - - extension (v: Vector[ScalarRes]) - infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _) - infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _) - infix def scale(s: ScalarRes) = v.map(_ * s) - - def sumRes: ScalarRes = v.headOption match - case None => 0 - case Some(value) => - value match - case f: Float => v.asInstanceOf[Vector[Float]].sum - case n: Int => v.asInstanceOf[Vector[Int]].sum - case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans") - - infix def dot(that: Vector[ScalarRes]) = v - .zip(that) - .map(_ * _) - .sumRes + type Result = ScalarRes | Vector[ScalarRes] extension (r: Result) def negate: Result = r match case s: ScalarRes => s.neg case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd - infix def add(that: Result): Result = (r, that) match + def bitNeg: Int = r match + case sr: ScalarRes => ~sr + case _ => throw IllegalArgumentException("bitwiseNeg: wrong argument type") + + def shiftLeft(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n << b + case _ => throw IllegalArgumentException("shiftLeft: incompatible argument types") + + def shiftRight(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n >> b + case _ => throw IllegalArgumentException("shiftRight: incompatible argument types") + + def bitAnd(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s & t + case _ => throw IllegalArgumentException("bitAnd: incompatible argument types") + + def bitOr(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s | t + case _ => throw IllegalArgumentException("bitOr: incompatible argument types") + + def bitXor(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s ^ t + case _ => throw IllegalArgumentException("bitXor: incompatible argument types") + + def add(that: Result): Result = (r, that) match case (s: ScalarRes, t: ScalarRes) => s + t case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t case _ => throw IllegalArgumentException("add: incompatible argument types") - infix def sub(that: Result): Result = (r, that) match + def sub(that: Result): Result = (r, that) match case (s: ScalarRes, t: ScalarRes) => s - t case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t case _ => throw IllegalArgumentException("sub: incompatible argument types") + + def mul(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s * t + case _ => throw IllegalArgumentException("mul: incompatible argument types") + + def div(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s / t + case _ => throw IllegalArgumentException("div: incompatible argument types") + + def mod(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s % t + case _ => throw IllegalArgumentException("mod: incompatible argument types") + + def scale(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: ScalarRes) => v scale t + case _ => throw IllegalArgumentException("scale: incompatible argument types") + + def dot(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v dot t + case _ => throw IllegalArgumentException("dot: incompatible argument types") + + def &&(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + def ||(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + def gt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr > t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def lt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr < t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def gteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr >= t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def lteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr <= t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def eql(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr === t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t + case _ => throw IllegalArgumentException("gt: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala new file mode 100644 index 00000000..68683ddd --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala @@ -0,0 +1,92 @@ +package io.computenode.cyfra.interpreter + +object ScalarResult: + type ScalarRes = Float | Int | Boolean + + extension (sr: ScalarRes) + def neg: ScalarRes = sr match + case f: Float => -f + case n: Int => -n + case b: Boolean => !b + + infix def unary_~ : Int = sr match + case n: Int => ~n + case _ => throw IllegalArgumentException("bitNeg: wrong argument type") + + infix def <<(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n << b + case _ => throw IllegalArgumentException("<<: incompatible argument types") + + infix def >>(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n >> b + case _ => throw IllegalArgumentException(">>: incompatible argument types") + + infix def &(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m & n + case _ => throw IllegalArgumentException("&: incompatible argument types") + + infix def |(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m | n + case _ => throw IllegalArgumentException("|: incompatible argument types") + + infix def ^(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m ^ n + case _ => throw IllegalArgumentException("^: incompatible argument types") + + infix def +(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f + t + case (n: Int, t: Int) => n + t + case _ => throw IllegalArgumentException("+: incompatible argument types") + + infix def -(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f - t + case (n: Int, t: Int) => n - t + case _ => throw IllegalArgumentException("-: incompatible argument types") + + infix def *(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f * t + case (n: Int, t: Int) => n * t + case _ => throw IllegalArgumentException("*: incompatible argument types") + + infix def /(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f / t + case (n: Int, t: Int) => n / t + case _ => throw IllegalArgumentException("/: incompatible argument types") + + infix def %(that: ScalarRes): Int = (sr, that) match + case (n: Int, t: Int) => n % t + case _ => throw IllegalArgumentException("%: incompatible argument types") + + infix def &&(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + infix def ||(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + infix def >(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f > t + case (n: Int, t: Int) => n > t + case _ => throw IllegalArgumentException(">: incompatible argument types") + + infix def <(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f < t + case (n: Int, t: Int) => n < t + case _ => throw IllegalArgumentException("<: incompatible argument types") + + infix def >=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f >= t + case (n: Int, t: Int) => n >= t + case _ => throw IllegalArgumentException(">=: incompatible argument types") + + infix def <=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f <= t + case (n: Int, t: Int) => n <= t + case _ => throw IllegalArgumentException("<=: incompatible argument types") + + infix def ===(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => Math.abs(f - t) < 0.001f + case (n: Int, t: Int) => n == t + case (b: Boolean, t: Boolean) => b == t + case _ => throw IllegalArgumentException("eqls: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index f82c80c0..ba16b3ee 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -4,31 +4,47 @@ import io.computenode.cyfra.dsl.{*, given} import binding.*, macros.FnCall.FnIdentifier, control.Scope import collections.*, GArray.GArrayElem, GSeq.{CurrentElem, AggregateElem, FoldSeq} import struct.*, GStruct.{ComposeStruct, GetField} +import io.computenode.cyfra.spirv.BlockBuilder.buildBlock +import collection.mutable.Map as MMap object Simulate: import Result.* def sim(v: Value): Result = sim(v.tree) // helpful wrapper for Value instead of Expression - def sim(e: Expression[?]): Result = e match + def sim(e: Expression[?]): Result = + val exprMap = MMap.empty[Int, Result] // treeid of Expr -> result of evaluating that Expr + val blocks = buildBlock(e) + simIterate(blocks)(using exprMap) + + @annotation.tailrec + def simIterate(blocks: List[Expression[?]])(using exprMap: MMap[Int, Result]): Result = blocks match + case head :: Nil => simOne(head) + case head :: next => + val result = simOne(head) + exprMap.addOne(head.treeid -> result) + simIterate(next) + case Nil => ??? // should not happen + + def simOne(e: Expression[?])(using exprMap: MMap[Int, Result]): Result = e match case e: PhantomExpression[?] => simPhantom(e) case Negate(a) => simValue(a).negate case e: BinaryOpExpression[?] => simBinOp(e) - case ScalarProd(a, b) => simVector(a) scale simScalar(b) - case DotProd(a, b) => simVector(a) dot simVector(b) + case ScalarProd(a, b) => simVector(a).scale(simScalar(b)) + case DotProd(a, b) => simVector(a).dot(simVector(b)) case e: BitwiseOpExpression[?] => simBitwiseOp(e) case e: ComparisonOpExpression[?] => simCompareOp(e) case And(a, b) => simScalar(a) && simScalar(b) case Or(a, b) => simScalar(a) || simScalar(b) - case Not(a) => simScalar(a).neg - case ExtractScalar(a, i) => ??? + case Not(a) => simScalar(a).negate + case ExtractScalar(a, i) => ??? // simVector(a), simConst(i.tree) case e: ConvertExpression[?, ?] => simConvert(e) case e: Const[?] => simConst(e) case ComposeVec2(a, b) => Vector(simScalar(a), simScalar(b)) case ComposeVec3(a, b, c) => Vector(simScalar(a), simScalar(b), simScalar(c)) case ComposeVec4(a, b, c, d) => Vector(simScalar(a), simScalar(b), simScalar(c), simScalar(d)) - case ExtFunctionCall(fn, args) => simExtFunc(fn, args.map(simValue)) - case FunctionCall(fn, body, args) => simFunc(fn, simScope(body), args.map(simValue)) + case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue)) + case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue)) case InvocationId => ??? case Pass(value) => ??? case Dynamic(source) => ??? @@ -37,135 +53,117 @@ object Simulate: case e: ReadUniform[?] => simReadUniform(e) case e: GArrayElem[?] => simGArrayElem(e) case e: FoldSeq[?, ?] => simFoldSeq(e) - case e: ComposeStruct[?] => ??? - case e: GetField[?, ?] => ??? - case _ => throw IllegalArgumentException("wrong argument") + case e: ComposeStruct[?] => simComposeStruct(e) + case e: GetField[?, ?] => simGetField(e) + case _ => throw IllegalArgumentException("sim: wrong argument") - private def simPhantom(e: PhantomExpression[?]): Result = e match + private def simPhantom(e: PhantomExpression[?])(using exprMap: MMap[Int, Result]): Result = e match case CurrentElem(tid: Int) => ??? case AggregateElem(tid: Int) => ??? - private def simBinOp(e: BinaryOpExpression[?]): Result = e match - case Sum(a, b) => simValue(a) add simValue(b) // scalar or vector - case Diff(a, b) => simValue(a) sub simValue(b) // scalar or vector - case Mul(a, b) => simScalar(a) * simScalar(b) - case Div(a, b) => simScalar(a) / simScalar(b) - case Mod(a, b) => simScalar(a) % simScalar(b) + private def simBinOp(e: BinaryOpExpression[?])(using exprMap: MMap[Int, Result]): Result = e match + case Sum(a, b) => simValue(a).add(simValue(b)) // scalar or vector + case Diff(a, b) => simValue(a).sub(simValue(b)) // scalar or vector + case Mul(a, b) => simScalar(a).mul(simScalar(b)) + case Div(a, b) => simScalar(a).div(simScalar(b)) + case Mod(a, b) => simScalar(a).mod(simScalar(b)) - private def simBitwiseOp(e: BitwiseOpExpression[?]): Int = e match + private def simBitwiseOp(e: BitwiseOpExpression[?])(using exprMap: MMap[Int, Result]): Int = e match case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) - case BitwiseNot(a) => - simScalar(a) match - case m: Int => ~m - case _ => throw IllegalArgumentException("BitwiseNot: wrong argument type") - case ShiftLeft(a, by) => - (simScalar(a), simScalar(by)) match - case (m: Int, n: Int) => m << n - case _ => throw IllegalArgumentException("ShiftLeft: wrong argument types") - case ShiftRight(a, by) => - (simScalar(a), simScalar(by)) match - case (m: Int, n: Int) => m >> n - case _ => throw IllegalArgumentException("ShiftRight: wrong argument types") - - private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?]) = e match - case BitwiseAnd(a, b) => - (simScalar(a), simScalar(b)) match - case (m: Int, n: Int) => m & n - case _ => throw IllegalArgumentException("BitwiseAnd: wrong argument types") - case BitwiseOr(a, b) => - (simScalar(a), simScalar(b)) match - case (m: Int, n: Int) => m | n - case _ => throw IllegalArgumentException("BitwiseOr: wrong argument types") - case BitwiseXor(a, b) => - (simScalar(a), simScalar(b)) match - case (m: Int, n: Int) => m ^ n - case _ => throw IllegalArgumentException("BitwiseXor: wrong argument types") - - private def simCompareOp(e: ComparisonOpExpression[?]): Boolean = e match - case GreaterThan(a, b) => simScalar(a) > simScalar(b) - case LessThan(a, b) => simScalar(a) < simScalar(b) - case GreaterThanEqual(a, b) => simScalar(a) >= simScalar(b) - case LessThanEqual(a, b) => simScalar(a) <= simScalar(b) - case Equal(a, b) => simScalar(a) eql simScalar(b) - - private def simConvert(e: ConvertExpression[?, ?]): Float | Int = e match + case BitwiseNot(a) => simScalar(a).bitNeg + case ShiftLeft(a, by) => simScalar(a).shiftLeft(simScalar(by)) + case ShiftRight(a, by) => simScalar(a).shiftRight(simScalar(by)) + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using exprMap: MMap[Int, Result]): Int = e match + case BitwiseAnd(a, b) => simScalar(a).bitAnd(simScalar(b)) + case BitwiseOr(a, b) => simScalar(a).bitOr(simScalar(b)) + case BitwiseXor(a, b) => simScalar(a).bitXor(simScalar(b)) + + private def simCompareOp(e: ComparisonOpExpression[?])(using exprMap: MMap[Int, Result]): Boolean = e match + case GreaterThan(a, b) => simScalar(a).gt(simScalar(b)) + case LessThan(a, b) => simScalar(a).lt(simScalar(b)) + case GreaterThanEqual(a, b) => simScalar(a).gteq(simScalar(b)) + case LessThanEqual(a, b) => simScalar(a).lteq(simScalar(b)) + case Equal(a, b) => simScalar(a).eql(simScalar(b)) + + private def simConvert(e: ConvertExpression[?, ?])(using exprMap: MMap[Int, Result]): Float | Int = e match case ToFloat32(a) => - simScalar(a) match + exprMap(a.treeid) match case f: Float => f case _ => throw IllegalArgumentException("ToFloat32: wrong argument type") case ToInt32(a) => - simScalar(a) match + exprMap(a.treeid) match case n: Int => n case _ => throw IllegalArgumentException("ToInt32: wrong argument type") case ToUInt32(a) => - simScalar(a) match + exprMap(a.treeid) match case n: Int => n case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") - private def simConst(e: Const[?]): ScalarRes = e match + private def simConst(e: Const[?])(using exprMap: MMap[Int, Result]): ScalarRes = e match case ConstFloat32(value) => value case ConstInt32(value) => value case ConstUInt32(value) => value case ConstGB(value) => value - private def simValue(v: Value): Result = v match + private def simValue(v: Value)(using exprMap: MMap[Int, Result]): Result = v match case v: Scalar => simScalar(v) case v: Vec[?] => simVector(v) - private def simScalar(v: Scalar): ScalarRes = v match - case v: FloatType => sim(v.tree).asInstanceOf[Float] - case v: IntType => sim(v.tree).asInstanceOf[Int] - case v: UIntType => sim(v.tree).asInstanceOf[Int] - case GBoolean(source) => sim(source).asInstanceOf[Boolean] + private def simScalar(v: Scalar)(using exprMap: MMap[Int, Result]): ScalarRes = v match + case v: FloatType => exprMap(v.tree.treeid).asInstanceOf[Float] + case v: IntType => exprMap(v.tree.treeid).asInstanceOf[Int] + case v: UIntType => exprMap(v.tree.treeid).asInstanceOf[Int] + case GBoolean(source) => exprMap(source.treeid).asInstanceOf[Boolean] - private def simVector(v: Vec[?]): Vector[ScalarRes] = v match - case Vec2(tree) => sim(tree).asInstanceOf[Vector[ScalarRes]] - case Vec3(tree) => sim(tree).asInstanceOf[Vector[ScalarRes]] - case Vec4(tree) => sim(tree).asInstanceOf[Vector[ScalarRes]] + private def simVector(v: Vec[?])(using exprMap: MMap[Int, Result]) = v match + case Vec2(tree) => exprMap(v.tree.treeid) + case Vec3(tree) => exprMap(v.tree.treeid) + case Vec4(tree) => exprMap(v.tree.treeid) - private def simExtFunc(fn: FunctionName, args: List[Result]): Result = ??? - private def simFunc(fn: FnIdentifier, body: Result, args: List[Result]): Result = ??? - private def simScope(body: Scope[?]) = sim(body.expr) + private def simExtFunc(fn: FunctionName, args: List[Result])(using exprMap: MMap[Int, Result]): Result = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result])(using exprMap: MMap[Int, Result]): Result = ??? + private def simScope(body: Scope[?])(using exprMap: MMap[Int, Result]) = exprMap(body.expr.treeid) @annotation.tailrec private def whenHelper( - when: GBoolean, + when: Expression[GBoolean], thenCode: Scope[?], otherConds: List[Scope[GBoolean]], otherCaseCodes: List[Scope[?]], otherwise: Scope[?], - ): Result = - if sim(when.tree).asInstanceOf[Boolean] then sim(thenCode.expr) + )(using exprMap: MMap[Int, Result]): Result = + if exprMap(when.treeid).asInstanceOf[Boolean] then sim(thenCode.expr) else otherConds.headOption match - case None => sim(otherwise.expr) + case None => exprMap(otherwise.expr.treeid) case Some(cond) => whenHelper( - when = GBoolean(cond.expr), + when = cond.expr, thenCode = otherCaseCodes.head, otherConds = otherConds.tail, otherCaseCodes = otherCaseCodes.tail, otherwise = otherwise, ) - private def simWhen(e: WhenExpr[?]): Result = e match + private def simWhen(e: WhenExpr[?])(using exprMap: MMap[Int, Result]): Result = e match case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => - whenHelper(when, thenCode, otherConds, otherCaseCodes, otherwise) + whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise) - private def simReadBuffer(buf: ReadBuffer[?]): Result = buf match + private def simReadBuffer(buf: ReadBuffer[?])(using exprMap: MMap[Int, Result]): Result = buf match case ReadBuffer(buffer, index) => ??? - private def simReadUniform(uni: ReadUniform[?]): Result = uni match + private def simReadUniform(uni: ReadUniform[?])(using exprMap: MMap[Int, Result]): Result = uni match case ReadUniform(uniform) => ??? - private def simGArrayElem(gElem: GArrayElem[?]): Result = gElem match + private def simGArrayElem(gElem: GArrayElem[?])(using exprMap: MMap[Int, Result]): Result = gElem match case GArrayElem(index, i) => ??? - private def simFoldSeq(seq: FoldSeq[?, ?]): Result = seq match + private def simFoldSeq(seq: FoldSeq[?, ?])(using exprMap: MMap[Int, Result]): Result = seq match case FoldSeq(zero, fn, seq) => ??? - private def simComposeStruct(cs: ComposeStruct[?]): Result = cs match + private def simComposeStruct(cs: ComposeStruct[?])(using exprMap: MMap[Int, Result]): Result = cs match case ComposeStruct(fields, resultSchema) => ??? - private def simGetField(gf: GetField[?, ?]): Result = gf match + private def simGetField(gf: GetField[?, ?])(using exprMap: MMap[Int, Result]): Result = gf match case GetField(struct, fieldIndex) => ??? diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala new file mode 100644 index 00000000..dde9ce36 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala @@ -0,0 +1,23 @@ +package io.computenode.cyfra.interpreter + +object VectorResult: + import ScalarResult.* + + extension (v: Vector[ScalarRes]) + infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _) + infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _) + infix def eql(that: Vector[ScalarRes]): Boolean = v.zip(that).forall(_ === _) + infix def scale(s: ScalarRes) = v.map(_ * s) + + def sumRes: Float | Int = v.headOption match + case None => 0 + case Some(value) => + value match + case f: Float => v.asInstanceOf[Vector[Float]].sum + case n: Int => v.asInstanceOf[Vector[Int]].sum + case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans") + + infix def dot(that: Vector[ScalarRes]): Float | Int = v + .zip(that) + .map(_ * _) + .sumRes From feed0e4d8eb8f896eaae99b88090f914f23588b4 Mon Sep 17 00:00:00 2001 From: spamegg Date: Sat, 12 Jul 2025 13:47:27 +0300 Subject: [PATCH 03/13] handle ExtractScalar, some cleanup --- .../cyfra/e2e/interpreter/SimulateTests.scala | 35 +++++++++++-------- .../cyfra/interpreter/Interpreter.scala | 34 +++++++++--------- .../cyfra/interpreter/Result.scala | 10 +++--- .../cyfra/interpreter/ScalarResult.scala | 4 +-- .../cyfra/interpreter/Simulate.scala | 16 ++++----- 5 files changed, 51 insertions(+), 48 deletions(-) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index 5f3a414e..85005170 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -21,12 +21,17 @@ class SimulateE2eTest extends munit.FunSuite: val expected = 2 assert(result == expected, s"Expected $expected, got $result") - test("simulate vec4, scalar, dot"): + test("simulate vec4, scalar, dot, extract scalar"): val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) val res1 = Simulate.sim(v1) val exp1 = Vector(1f, 2f, 3f, 4f) assert(res1 == exp1, s"Expected $exp1, got $res1") + val i: Int32 = 2 + val res = Simulate.sim(ExtractScalar(fromExpr(v1), i)) + val exp = 3f + assert(res == exp, s"Expected $exp, got $res") + val v2 = ScalarProd(fromExpr(v1), -1f) val res2 = Simulate.sim(v2) val exp2 = Vector(-1f, -2f, -3f, -4f) @@ -38,19 +43,6 @@ class SimulateE2eTest extends munit.FunSuite: val exp3 = 0f assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3") - // currently not working due to Scope - test("simulate when elseWhen otherwise".ignore): - val expr = WhenExpr( - when = 2 <= 1, - thenCode = Scope(ConstInt32(1)), - otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 >= 3))), - otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), - otherwise = Scope(ConstInt32(3)), - ) - val res = Simulate.sim(expr) - val exp = 3 - assert(res == exp, s"Expected $exp, got $res") - test("simulate bitwise ops"): val a: Int32 = 5 val by: UInt32 = 3 @@ -65,10 +57,23 @@ class SimulateE2eTest extends munit.FunSuite: val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) assert(res == exp, s"Expected $exp, got $res") - test("simulate stack overflow"): + test("simulate stack overflow".ignore): val a: Int32 = 1 var sum = Sum(a, a) // 2 for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) val res = Simulate.sim(sum) val exp = 1000002 assert(res == exp, s"Expected $exp, got $res") + + // currently not working due to Scope + test("simulate when elseWhen otherwise".ignore): + val expr = WhenExpr( + when = 2 <= 1, + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 >= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val res = Simulate.sim(expr) + val exp = 3 + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index c9473c3e..48f07619 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -30,28 +30,26 @@ object Interpreter: val invocSimRes = InvocSimResult(invocId = id, values = List(value)) SimResult(List(invocSimRes)) - def interpretWriteBuffer(gio: Pure[?]): SimResult = gio match - case Pure(value) => - val id = Simulate.sim(invocationId).asInstanceOf[Int] - val invocSimRes = InvocSimResult(invocId = id, values = List(value)) - SimResult(List(invocSimRes)) + def interpretWriteBuffer(gio: WriteBuffer[?]): SimResult = gio match + case WriteBuffer(buffer, index, value) => ??? + + def interpretWriteUniform(gio: WriteUniform[?]): SimResult = gio match + case WriteUniform(uniform, value) => ??? def interpretOne(gio: GIO[?]): SimResult = gio match - case Pure(value) => ??? - case WriteBuffer(buffer, index, value) => ??? - case WriteUniform(uniform, value) => ??? - case _ => throw IllegalArgumentException("interpret: invalid GIO") + case p: Pure[?] => interpretPure(p) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb) + case wu: WriteUniform[?] => interpretWriteUniform(wu) + case _ => throw IllegalArgumentException("interpret: invalid GIO") @annotation.tailrec def interpretMany(gios: List[GIO[?]], simRes: SimResult): SimResult = gios match - case head :: tail => - head match - case FlatMap(gio, next) => interpretMany(gio :: next :: tail, simRes) - case Repeat(n, f) => - val int = Simulate.sim(n).asInstanceOf[Int] - val newGios = (0 until int).map(i => f(i)).toList - interpretMany(newGios ::: tail, simRes) - case _ => interpretMany(tail, simRes.add(interpretOne(head))) - case Nil => simRes + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, simRes) + case Repeat(n, f) :: tail => + val int = Simulate.sim(n).asInstanceOf[Int] + val newGios = (0 until int).map(i => f(i)).toList + interpretMany(newGios ::: tail, simRes) + case head :: tail => interpretMany(tail, simRes.add(interpretOne(head))) + case Nil => simRes def interpret(gio: GIO[?]): SimResult = interpretMany(List(gio), SimResult()) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala index 23307374..ad3d7634 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala @@ -12,7 +12,7 @@ object Result: def bitNeg: Int = r match case sr: ScalarRes => ~sr - case _ => throw IllegalArgumentException("bitwiseNeg: wrong argument type") + case _ => throw IllegalArgumentException("bitNeg: wrong argument type") def shiftLeft(by: Result): Int = (r, by) match case (n: ScalarRes, b: ScalarRes) => n << b @@ -78,17 +78,17 @@ object Result: def lt(that: Result): Boolean = (r, that) match case (sr: ScalarRes, t: ScalarRes) => sr < t - case _ => throw IllegalArgumentException("gt: incompatible argument types") + case _ => throw IllegalArgumentException("lt: incompatible argument types") def gteq(that: Result): Boolean = (r, that) match case (sr: ScalarRes, t: ScalarRes) => sr >= t - case _ => throw IllegalArgumentException("gt: incompatible argument types") + case _ => throw IllegalArgumentException("gteq: incompatible argument types") def lteq(that: Result): Boolean = (r, that) match case (sr: ScalarRes, t: ScalarRes) => sr <= t - case _ => throw IllegalArgumentException("gt: incompatible argument types") + case _ => throw IllegalArgumentException("lteq: incompatible argument types") def eql(that: Result): Boolean = (r, that) match case (sr: ScalarRes, t: ScalarRes) => sr === t case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t - case _ => throw IllegalArgumentException("gt: incompatible argument types") + case _ => throw IllegalArgumentException("eql: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala index 68683ddd..8d3e537f 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala @@ -11,7 +11,7 @@ object ScalarResult: infix def unary_~ : Int = sr match case n: Int => ~n - case _ => throw IllegalArgumentException("bitNeg: wrong argument type") + case _ => throw IllegalArgumentException("~: wrong argument type") infix def <<(by: ScalarRes): Int = (sr, by) match case (n: Int, b: Int) => n << b @@ -89,4 +89,4 @@ object ScalarResult: case (f: Float, t: Float) => Math.abs(f - t) < 0.001f case (n: Int, t: Int) => n == t case (b: Boolean, t: Boolean) => b == t - case _ => throw IllegalArgumentException("eqls: incompatible argument types") + case _ => throw IllegalArgumentException("===: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index ba16b3ee..1ce84ee0 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -37,7 +37,7 @@ object Simulate: case And(a, b) => simScalar(a) && simScalar(b) case Or(a, b) => simScalar(a) || simScalar(b) case Not(a) => simScalar(a).negate - case ExtractScalar(a, i) => ??? // simVector(a), simConst(i.tree) + case ExtractScalar(a, i) => simVector(a).apply(simValue(i).asInstanceOf[Int]) case e: ConvertExpression[?, ?] => simConvert(e) case e: Const[?] => simConst(e) case ComposeVec2(a, b) => Vector(simScalar(a), simScalar(b)) @@ -100,7 +100,7 @@ object Simulate: case n: Int => n case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") - private def simConst(e: Const[?])(using exprMap: MMap[Int, Result]): ScalarRes = e match + private def simConst(e: Const[?]): ScalarRes = e match case ConstFloat32(value) => value case ConstInt32(value) => value case ConstUInt32(value) => value @@ -117,13 +117,13 @@ object Simulate: case GBoolean(source) => exprMap(source.treeid).asInstanceOf[Boolean] private def simVector(v: Vec[?])(using exprMap: MMap[Int, Result]) = v match - case Vec2(tree) => exprMap(v.tree.treeid) - case Vec3(tree) => exprMap(v.tree.treeid) - case Vec4(tree) => exprMap(v.tree.treeid) + case Vec2(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec3(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec4(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] private def simExtFunc(fn: FunctionName, args: List[Result])(using exprMap: MMap[Int, Result]): Result = ??? private def simFunc(fn: FnIdentifier, body: Result, args: List[Result])(using exprMap: MMap[Int, Result]): Result = ??? - private def simScope(body: Scope[?])(using exprMap: MMap[Int, Result]) = exprMap(body.expr.treeid) + private def simScope(body: Scope[?])(using exprMap: MMap[Int, Result]) = exprMap(body.rootTreeId) @annotation.tailrec private def whenHelper( @@ -133,10 +133,10 @@ object Simulate: otherCaseCodes: List[Scope[?]], otherwise: Scope[?], )(using exprMap: MMap[Int, Result]): Result = - if exprMap(when.treeid).asInstanceOf[Boolean] then sim(thenCode.expr) + if exprMap(when.treeid).asInstanceOf[Boolean] then exprMap(thenCode.expr.treeid) else otherConds.headOption match - case None => exprMap(otherwise.expr.treeid) + case None => exprMap(otherwise.rootTreeId) case Some(cond) => whenHelper( when = cond.expr, From 88dda4485561282d18bd14101efdabdde857b6e4 Mon Sep 17 00:00:00 2001 From: spamegg Date: Sun, 13 Jul 2025 17:58:27 +0300 Subject: [PATCH 04/13] simulate when expressions and scope --- .../cyfra/e2e/interpreter/SimulateTests.scala | 15 +----- .../e2e/interpreter/SimulateWhenTests.scala | 54 +++++++++++++++++++ .../cyfra/interpreter/Simulate.scala | 4 +- 3 files changed, 57 insertions(+), 16 deletions(-) create mode 100644 cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index 85005170..b8dcfe1d 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -57,23 +57,10 @@ class SimulateE2eTest extends munit.FunSuite: val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) assert(res == exp, s"Expected $exp, got $res") - test("simulate stack overflow".ignore): + test("simulate should not stack overflow"): val a: Int32 = 1 var sum = Sum(a, a) // 2 for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) val res = Simulate.sim(sum) val exp = 1000002 assert(res == exp, s"Expected $exp, got $res") - - // currently not working due to Scope - test("simulate when elseWhen otherwise".ignore): - val expr = WhenExpr( - when = 2 <= 1, - thenCode = Scope(ConstInt32(1)), - otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 >= 3))), - otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), - otherwise = Scope(ConstInt32(3)), - ) - val res = Simulate.sim(expr) - val exp = 3 - assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala new file mode 100644 index 00000000..a2dba264 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -0,0 +1,54 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import Value.FromExpr.fromExpr, control.Scope + +class SimulateWhenE2eTest extends munit.FunSuite: + test("simulate when"): + val expr1 = WhenExpr( + when = 2 >= 1, // true + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val res1 = Simulate.sim(expr1) + val exp1 = 1 + assert(res1 == exp1, s"Expected $exp1, got $res1") + + test("simulate elseWhen first"): + val expr2 = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 >= 2)) /*true*/, Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val res2 = Simulate.sim(expr2) + val exp2 = 2 + assert(res2 == exp2, s"Expected $exp2, got $res2") + + test("simulate elseWhen second"): + val expr3 = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 <= 3))), // true + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val res3 = Simulate.sim(expr3) + val exp3 = 4 + assert(res3 == exp3, s"Expected $exp3, got $res3") + + test("simulate otherwise"): + val expr4 = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 >= 3))), // false + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val res4 = Simulate.sim(expr4) + val exp4 = 3 + assert(res4 == exp4, s"Expected $exp4, got $res4") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 1ce84ee0..37dee1c8 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -133,10 +133,10 @@ object Simulate: otherCaseCodes: List[Scope[?]], otherwise: Scope[?], )(using exprMap: MMap[Int, Result]): Result = - if exprMap(when.treeid).asInstanceOf[Boolean] then exprMap(thenCode.expr.treeid) + if sim(when).asInstanceOf[Boolean] then sim(thenCode.expr) else otherConds.headOption match - case None => exprMap(otherwise.rootTreeId) + case None => sim(otherwise.expr) case Some(cond) => whenHelper( when = cond.expr, From f1177b696655926c8178e07d31064c476aa2a919 Mon Sep 17 00:00:00 2001 From: spamegg Date: Mon, 14 Jul 2025 19:17:55 +0300 Subject: [PATCH 05/13] add Read support to expression Simulator --- .../cyfra/e2e/interpreter/SimulateTests.scala | 20 ++++- .../cyfra/interpreter/Interpreter.scala | 46 +++++------ .../cyfra/interpreter/SimContext.scala | 16 ++++ .../cyfra/interpreter/Simulate.scala | 77 +++++++++---------- 4 files changed, 92 insertions(+), 67 deletions(-) create mode 100644 cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index b8dcfe1d..3e935e94 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -1,8 +1,9 @@ package io.computenode.cyfra.e2e.interpreter import io.computenode.cyfra.interpreter.*, Result.* -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.{*, given}, binding.{ReadBuffer, GBuffer} import Value.FromExpr.fromExpr, control.Scope +import izumi.reflect.Tag class SimulateE2eTest extends munit.FunSuite: test("simulate binary operation arithmetic"): @@ -64,3 +65,20 @@ class SimulateE2eTest extends munit.FunSuite: val res = Simulate.sim(sum) val exp = 1000002 assert(res == exp, s"Expected $exp, got $res") + + test("simulate ReadBuffer"): + // We fake a GBuffer with an array + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (0 until 1024).toArray[Result] + + given sc: SimContext = SimContext() + sc.addBuffer(buffer, array) + + val expr = ReadBuffer(buffer, 128) + val res = Simulate.sim(expr)(using sc) + val exp = 128 + assert(res == exp, s"Expected $exp, got $res") + + // the context should keep track of the read + assert(sc.reads.contains(Read(buffer, 128))) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 48f07619..2c28b8b0 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -4,46 +4,38 @@ import io.computenode.cyfra.dsl.{*, given} import binding.*, Value.*, gio.GIO, GIO.* import izumi.reflect.Tag +case class Write(buffer: GBuffer[?], index: Int, value: Any) +case class InvocResult( + invocId: Int, + instructions: List[Expression[?]] = Nil, + values: List[Any] = Nil, + writes: List[Write] = Nil, + reads: List[Read] = Nil, +) +case class InterpretResult(invocs: List[InvocResult] = Nil): + def add(that: InterpretResult) = InterpretResult(that.invocs ::: invocs) + object Interpreter: - case class Write(buffer: GBuffer[?], index: Int, value: Any) - case class Read(buffer: GBuffer[?], index: Int) - case class InvocSimResult( - invocId: Int, - instructions: List[Expression[?]] = Nil, - values: List[Any] = Nil, - writes: List[Write] = Nil, - reads: List[Read] = Nil, - ) - case class SimResult(invocs: List[InvocSimResult] = Nil): - def add(that: SimResult) = SimResult(that.invocs ::: invocs) - - case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] - - // val bufferA = SimGBuffer[Int32]() - // val gbuffers = Map[GBuffer[?], Array[Int32]](bufferA -> Array.fill(1024)(0)) - // val expression = bufferA.read(0) + 2 - // val res = Simulate.sim(expression, 1024) -> SimulationResult(???) - - def interpretPure(gio: Pure[?]): SimResult = gio match + def interpretPure(gio: Pure[?])(using SimContext): InterpretResult = gio match case Pure(value) => val id = Simulate.sim(invocationId).asInstanceOf[Int] - val invocSimRes = InvocSimResult(invocId = id, values = List(value)) - SimResult(List(invocSimRes)) + val invocSimRes = InvocResult(invocId = id, values = List(value)) + InterpretResult(List(invocSimRes)) - def interpretWriteBuffer(gio: WriteBuffer[?]): SimResult = gio match + def interpretWriteBuffer(gio: WriteBuffer[?]): InterpretResult = gio match case WriteBuffer(buffer, index, value) => ??? - def interpretWriteUniform(gio: WriteUniform[?]): SimResult = gio match + def interpretWriteUniform(gio: WriteUniform[?]): InterpretResult = gio match case WriteUniform(uniform, value) => ??? - def interpretOne(gio: GIO[?]): SimResult = gio match + def interpretOne(gio: GIO[?])(using SimContext): InterpretResult = gio match case p: Pure[?] => interpretPure(p) case wb: WriteBuffer[?] => interpretWriteBuffer(wb) case wu: WriteUniform[?] => interpretWriteUniform(wu) case _ => throw IllegalArgumentException("interpret: invalid GIO") @annotation.tailrec - def interpretMany(gios: List[GIO[?]], simRes: SimResult): SimResult = gios match + def interpretMany(gios: List[GIO[?]], simRes: InterpretResult)(using SimContext): InterpretResult = gios match case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, simRes) case Repeat(n, f) :: tail => val int = Simulate.sim(n).asInstanceOf[Int] @@ -52,4 +44,4 @@ object Interpreter: case head :: tail => interpretMany(tail, simRes.add(interpretOne(head))) case Nil => simRes - def interpret(gio: GIO[?]): SimResult = interpretMany(List(gio), SimResult()) + def interpret(gio: GIO[?])(using SimContext): InterpretResult = interpretMany(List(gio), InterpretResult()) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala new file mode 100644 index 00000000..c2007d96 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.GBuffer +import Result.Result + +import scala.collection.mutable.Map as MMap + +case class SimContext(exprMap: MMap[Int, Result] = MMap(), bufMap: MMap[GBuffer[?], Array[Result]] = MMap(), var reads: List[Read] = Nil): + def addBuffer(buffer: GBuffer[?], array: Array[Result]): Unit = bufMap.addOne(buffer -> array) + def addResult(treeid: Int, result: Result): Unit = exprMap.addOne(treeid -> result) + def addRead(buffer: GBuffer[?], index: Int): Unit = reads ::= Read(buffer, index) + def read(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) + def lookup(treeid: Int): Result = exprMap(treeid) + +case class Read(buffer: GBuffer[?], index: Int) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 37dee1c8..890c8052 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -11,22 +11,18 @@ object Simulate: import Result.* def sim(v: Value): Result = sim(v.tree) // helpful wrapper for Value instead of Expression - - def sim(e: Expression[?]): Result = - val exprMap = MMap.empty[Int, Result] // treeid of Expr -> result of evaluating that Expr - val blocks = buildBlock(e) - simIterate(blocks)(using exprMap) + def sim(e: Expression[?])(using sc: SimContext = SimContext()): Result = simIterate(buildBlock(e)) @annotation.tailrec - def simIterate(blocks: List[Expression[?]])(using exprMap: MMap[Int, Result]): Result = blocks match + def simIterate(blocks: List[Expression[?]])(using sc: SimContext): Result = blocks match case head :: Nil => simOne(head) case head :: next => val result = simOne(head) - exprMap.addOne(head.treeid -> result) + sc.addResult(head.treeid, result) simIterate(next) case Nil => ??? // should not happen - def simOne(e: Expression[?])(using exprMap: MMap[Int, Result]): Result = e match + def simOne(e: Expression[?])(using sc: SimContext): Result = e match case e: PhantomExpression[?] => simPhantom(e) case Negate(a) => simValue(a).negate case e: BinaryOpExpression[?] => simBinOp(e) @@ -57,46 +53,46 @@ object Simulate: case e: GetField[?, ?] => simGetField(e) case _ => throw IllegalArgumentException("sim: wrong argument") - private def simPhantom(e: PhantomExpression[?])(using exprMap: MMap[Int, Result]): Result = e match + private def simPhantom(e: PhantomExpression[?])(using sc: SimContext): Result = e match case CurrentElem(tid: Int) => ??? case AggregateElem(tid: Int) => ??? - private def simBinOp(e: BinaryOpExpression[?])(using exprMap: MMap[Int, Result]): Result = e match + private def simBinOp(e: BinaryOpExpression[?])(using sc: SimContext): Result = e match case Sum(a, b) => simValue(a).add(simValue(b)) // scalar or vector case Diff(a, b) => simValue(a).sub(simValue(b)) // scalar or vector case Mul(a, b) => simScalar(a).mul(simScalar(b)) case Div(a, b) => simScalar(a).div(simScalar(b)) case Mod(a, b) => simScalar(a).mod(simScalar(b)) - private def simBitwiseOp(e: BitwiseOpExpression[?])(using exprMap: MMap[Int, Result]): Int = e match + private def simBitwiseOp(e: BitwiseOpExpression[?])(using sc: SimContext): Int = e match case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) case BitwiseNot(a) => simScalar(a).bitNeg case ShiftLeft(a, by) => simScalar(a).shiftLeft(simScalar(by)) case ShiftRight(a, by) => simScalar(a).shiftRight(simScalar(by)) - private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using exprMap: MMap[Int, Result]): Int = e match + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using sc: SimContext): Int = e match case BitwiseAnd(a, b) => simScalar(a).bitAnd(simScalar(b)) case BitwiseOr(a, b) => simScalar(a).bitOr(simScalar(b)) case BitwiseXor(a, b) => simScalar(a).bitXor(simScalar(b)) - private def simCompareOp(e: ComparisonOpExpression[?])(using exprMap: MMap[Int, Result]): Boolean = e match + private def simCompareOp(e: ComparisonOpExpression[?])(using sc: SimContext): Boolean = e match case GreaterThan(a, b) => simScalar(a).gt(simScalar(b)) case LessThan(a, b) => simScalar(a).lt(simScalar(b)) case GreaterThanEqual(a, b) => simScalar(a).gteq(simScalar(b)) case LessThanEqual(a, b) => simScalar(a).lteq(simScalar(b)) case Equal(a, b) => simScalar(a).eql(simScalar(b)) - private def simConvert(e: ConvertExpression[?, ?])(using exprMap: MMap[Int, Result]): Float | Int = e match + private def simConvert(e: ConvertExpression[?, ?])(using sc: SimContext): Float | Int = e match case ToFloat32(a) => - exprMap(a.treeid) match + sc.lookup(a.treeid) match case f: Float => f case _ => throw IllegalArgumentException("ToFloat32: wrong argument type") case ToInt32(a) => - exprMap(a.treeid) match + sc.lookup(a.treeid) match case n: Int => n case _ => throw IllegalArgumentException("ToInt32: wrong argument type") case ToUInt32(a) => - exprMap(a.treeid) match + sc.lookup(a.treeid) match case n: Int => n case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") @@ -106,24 +102,23 @@ object Simulate: case ConstUInt32(value) => value case ConstGB(value) => value - private def simValue(v: Value)(using exprMap: MMap[Int, Result]): Result = v match + private def simValue(v: Value)(using sc: SimContext): Result = v match case v: Scalar => simScalar(v) case v: Vec[?] => simVector(v) - private def simScalar(v: Scalar)(using exprMap: MMap[Int, Result]): ScalarRes = v match - case v: FloatType => exprMap(v.tree.treeid).asInstanceOf[Float] - case v: IntType => exprMap(v.tree.treeid).asInstanceOf[Int] - case v: UIntType => exprMap(v.tree.treeid).asInstanceOf[Int] - case GBoolean(source) => exprMap(source.treeid).asInstanceOf[Boolean] + private def simScalar(v: Scalar)(using sc: SimContext): ScalarRes = v match + case v: FloatType => sc.lookup(v.tree.treeid).asInstanceOf[Float] + case v: IntType => sc.lookup(v.tree.treeid).asInstanceOf[Int] + case v: UIntType => sc.lookup(v.tree.treeid).asInstanceOf[Int] + case GBoolean(source) => sc.lookup(source.treeid).asInstanceOf[Boolean] - private def simVector(v: Vec[?])(using exprMap: MMap[Int, Result]) = v match - case Vec2(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] - case Vec3(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] - case Vec4(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + private def simVector(v: Vec[?])(using sc: SimContext) = v match + case Vec2(tree) => sc.lookup(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec3(tree) => sc.lookup(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec4(tree) => sc.lookup(tree.treeid).asInstanceOf[Vector[ScalarRes]] - private def simExtFunc(fn: FunctionName, args: List[Result])(using exprMap: MMap[Int, Result]): Result = ??? - private def simFunc(fn: FnIdentifier, body: Result, args: List[Result])(using exprMap: MMap[Int, Result]): Result = ??? - private def simScope(body: Scope[?])(using exprMap: MMap[Int, Result]) = exprMap(body.rootTreeId) + private def simExtFunc(fn: FunctionName, args: List[Result])(using sc: SimContext): Result = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result])(using sc: SimContext): Result = ??? @annotation.tailrec private def whenHelper( @@ -132,7 +127,7 @@ object Simulate: otherConds: List[Scope[GBoolean]], otherCaseCodes: List[Scope[?]], otherwise: Scope[?], - )(using exprMap: MMap[Int, Result]): Result = + )(using sc: SimContext): Result = if sim(when).asInstanceOf[Boolean] then sim(thenCode.expr) else otherConds.headOption match @@ -146,24 +141,28 @@ object Simulate: otherwise = otherwise, ) - private def simWhen(e: WhenExpr[?])(using exprMap: MMap[Int, Result]): Result = e match + private def simWhen(e: WhenExpr[?])(using sc: SimContext): Result = e match case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise) - private def simReadBuffer(buf: ReadBuffer[?])(using exprMap: MMap[Int, Result]): Result = buf match - case ReadBuffer(buffer, index) => ??? + private def simReadBuffer(buf: ReadBuffer[?])(using sc: SimContext): Result = buf match + case ReadBuffer(buffer, index) => + val i = sim(index).asInstanceOf[Int] + // sc.addBuffer(buffer, Array.fill(1024)(0)) // add a fake buffer represented by an array + sc.addRead(buffer, i) + sc.read(buffer, i) - private def simReadUniform(uni: ReadUniform[?])(using exprMap: MMap[Int, Result]): Result = uni match + private def simReadUniform(uni: ReadUniform[?])(using sc: SimContext): Result = uni match case ReadUniform(uniform) => ??? - private def simGArrayElem(gElem: GArrayElem[?])(using exprMap: MMap[Int, Result]): Result = gElem match + private def simGArrayElem(gElem: GArrayElem[?])(using sc: SimContext): Result = gElem match case GArrayElem(index, i) => ??? - private def simFoldSeq(seq: FoldSeq[?, ?])(using exprMap: MMap[Int, Result]): Result = seq match + private def simFoldSeq(seq: FoldSeq[?, ?])(using sc: SimContext): Result = seq match case FoldSeq(zero, fn, seq) => ??? - private def simComposeStruct(cs: ComposeStruct[?])(using exprMap: MMap[Int, Result]): Result = cs match + private def simComposeStruct(cs: ComposeStruct[?])(using sc: SimContext): Result = cs match case ComposeStruct(fields, resultSchema) => ??? - private def simGetField(gf: GetField[?, ?])(using exprMap: MMap[Int, Result]): Result = gf match + private def simGetField(gf: GetField[?, ?])(using sc: SimContext): Result = gf match case GetField(struct, fieldIndex) => ??? From c311c95175cdf9ad197a0220e1a787c28addff67 Mon Sep 17 00:00:00 2001 From: spamegg Date: Tue, 15 Jul 2025 18:13:36 +0300 Subject: [PATCH 06/13] start work on Write support in Interpreter --- .../cyfra/interpreter/Interpreter.scala | 48 +++++++++++-------- .../cyfra/interpreter/Simulate.scala | 1 - 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 2c28b8b0..f4f1b929 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -5,6 +5,7 @@ import binding.*, Value.*, gio.GIO, GIO.* import izumi.reflect.Tag case class Write(buffer: GBuffer[?], index: Int, value: Any) + case class InvocResult( invocId: Int, instructions: List[Expression[?]] = Nil, @@ -12,36 +13,45 @@ case class InvocResult( writes: List[Write] = Nil, reads: List[Read] = Nil, ) + case class InterpretResult(invocs: List[InvocResult] = Nil): def add(that: InterpretResult) = InterpretResult(that.invocs ::: invocs) object Interpreter: - def interpretPure(gio: Pure[?])(using SimContext): InterpretResult = gio match + def interpretPure(gio: Pure[?], invocId: Int)(using sc: SimContext): InterpretResult = gio match case Pure(value) => - val id = Simulate.sim(invocationId).asInstanceOf[Int] - val invocSimRes = InvocResult(invocId = id, values = List(value)) - InterpretResult(List(invocSimRes)) - - def interpretWriteBuffer(gio: WriteBuffer[?]): InterpretResult = gio match - case WriteBuffer(buffer, index, value) => ??? - - def interpretWriteUniform(gio: WriteUniform[?]): InterpretResult = gio match + InterpretResult: + List(InvocResult(invocId, ???, List(value), ???, sc.reads)) + + def interpretWriteBuffer(gio: WriteBuffer[?], invocId: Int)(using sc: SimContext): InterpretResult = gio match + case WriteBuffer(buffer, index, value) => + val i = Simulate.sim(index).asInstanceOf[Int] // reads happening here? + val res = Simulate.sim(value) // reads happening here? (has its own SimContext) + sc.bufMap(buffer)(i) = res + val writes = List(Write(buffer, i, res)) + val reads = ??? // get all the reads from the SimContext? + InterpretResult: + List(InvocResult(invocId, ???, List(value), writes, reads)) + + def interpretWriteUniform(gio: WriteUniform[?], invocId: Int)(using sc: SimContext): InterpretResult = gio match case WriteUniform(uniform, value) => ??? - def interpretOne(gio: GIO[?])(using SimContext): InterpretResult = gio match - case p: Pure[?] => interpretPure(p) - case wb: WriteBuffer[?] => interpretWriteBuffer(wb) - case wu: WriteUniform[?] => interpretWriteUniform(wu) - case _ => throw IllegalArgumentException("interpret: invalid GIO") + def interpretOne(gio: GIO[?])(using SimContext): InterpretResult = + val invocId = Simulate.sim(invocationId).asInstanceOf[Int] + gio match + case p: Pure[?] => interpretPure(p, invocId) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, invocId) + case wu: WriteUniform[?] => interpretWriteUniform(wu, invocId) + case _ => throw IllegalArgumentException("interpretOne: invalid GIO") @annotation.tailrec - def interpretMany(gios: List[GIO[?]], simRes: InterpretResult)(using SimContext): InterpretResult = gios match - case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, simRes) + def interpretMany(gios: List[GIO[?]], res: InterpretResult)(using SimContext): InterpretResult = gios match + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, res) case Repeat(n, f) :: tail => val int = Simulate.sim(n).asInstanceOf[Int] val newGios = (0 until int).map(i => f(i)).toList - interpretMany(newGios ::: tail, simRes) - case head :: tail => interpretMany(tail, simRes.add(interpretOne(head))) - case Nil => simRes + interpretMany(newGios ::: tail, res) + case head :: tail => interpretMany(tail, res.add(interpretOne(head))) + case Nil => res def interpret(gio: GIO[?])(using SimContext): InterpretResult = interpretMany(List(gio), InterpretResult()) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 890c8052..56a27ea0 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -148,7 +148,6 @@ object Simulate: private def simReadBuffer(buf: ReadBuffer[?])(using sc: SimContext): Result = buf match case ReadBuffer(buffer, index) => val i = sim(index).asInstanceOf[Int] - // sc.addBuffer(buffer, Array.fill(1024)(0)) // add a fake buffer represented by an array sc.addRead(buffer, i) sc.read(buffer, i) From 600c485b939f8211ba5b1b0dde11ae4eabe49ea3 Mon Sep 17 00:00:00 2001 From: spamegg Date: Thu, 17 Jul 2025 18:46:44 +0300 Subject: [PATCH 07/13] change SimContext to be immutable, add complex when test --- .../cyfra/e2e/interpreter/SimulateTests.scala | 22 +- .../e2e/interpreter/SimulateWhenTests.scala | 49 ++- .../cyfra/interpreter/Interpreter.scala | 2 +- .../cyfra/interpreter/SimContext.scala | 14 +- .../cyfra/interpreter/Simulate.scala | 289 +++++++++++------- 5 files changed, 247 insertions(+), 129 deletions(-) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index 3e935e94..531adc1e 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -18,29 +18,30 @@ class SimulateE2eTest extends munit.FunSuite: val e3 = Mul(f, fromExpr(e2)) val e4 = Div(fromExpr(e3), d) val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4) - val result = Simulate.sim(expr) + val (result, _) = Simulate.sim(expr) val expected = 2 assert(result == expected, s"Expected $expected, got $result") test("simulate vec4, scalar, dot, extract scalar"): val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) - val res1 = Simulate.sim(v1) + val (res1, _) = Simulate.sim(v1) val exp1 = Vector(1f, 2f, 3f, 4f) assert(res1 == exp1, s"Expected $exp1, got $res1") val i: Int32 = 2 - val res = Simulate.sim(ExtractScalar(fromExpr(v1), i)) + val (res, _) = Simulate.sim(ExtractScalar(fromExpr(v1), i)) val exp = 3f assert(res == exp, s"Expected $exp, got $res") val v2 = ScalarProd(fromExpr(v1), -1f) - val res2 = Simulate.sim(v2) + val (res2, _) = Simulate.sim(v2) val exp2 = Vector(-1f, -2f, -3f, -4f) assert(res2 == exp2, s"Expected $exp2, got $res2") val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f) val dot = DotProd(fromExpr(v1), fromExpr(v3)) - val res3 = Simulate.sim(dot).asInstanceOf[Float] + val (res3a, _) = Simulate.sim(dot) + val res3 = res3a.asInstanceOf[Float] val exp3 = 0f assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3") @@ -54,7 +55,7 @@ class SimulateE2eTest extends munit.FunSuite: val or = BitwiseOr(fromExpr(left), fromExpr(right)) val xor = BitwiseXor(fromExpr(and), fromExpr(or)) - val res = Simulate.sim(xor) + val (res, _) = Simulate.sim(xor) val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) assert(res == exp, s"Expected $exp, got $res") @@ -62,7 +63,7 @@ class SimulateE2eTest extends munit.FunSuite: val a: Int32 = 1 var sum = Sum(a, a) // 2 for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) - val res = Simulate.sim(sum) + val (res, _) = Simulate.sim(sum) val exp = 1000002 assert(res == exp, s"Expected $exp, got $res") @@ -72,13 +73,12 @@ class SimulateE2eTest extends munit.FunSuite: val buffer = SimGBuffer[Int32]() val array = (0 until 1024).toArray[Result] - given sc: SimContext = SimContext() - sc.addBuffer(buffer, array) + val sc = SimContext().addBuffer(buffer, array) val expr = ReadBuffer(buffer, 128) - val res = Simulate.sim(expr)(using sc) + val (res, newSc) = Simulate.sim(expr, sc) val exp = 128 assert(res == exp, s"Expected $exp, got $res") // the context should keep track of the read - assert(sc.reads.contains(Read(buffer, 128))) + assert(newSc.reads.contains(Read(buffer, 128)), "missing read") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala index a2dba264..841192b8 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -2,7 +2,8 @@ package io.computenode.cyfra.e2e.interpreter import io.computenode.cyfra.interpreter.*, Result.* import io.computenode.cyfra.dsl.{*, given} -import Value.FromExpr.fromExpr, control.Scope +import Value.FromExpr.fromExpr, control.Scope, binding.{GBuffer, ReadBuffer} +import izumi.reflect.Tag class SimulateWhenE2eTest extends munit.FunSuite: test("simulate when"): @@ -13,7 +14,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val res1 = Simulate.sim(expr1) + val (res1, _) = Simulate.sim(expr1) val exp1 = 1 assert(res1 == exp1, s"Expected $exp1, got $res1") @@ -25,7 +26,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val res2 = Simulate.sim(expr2) + val (res2, _) = Simulate.sim(expr2) val exp2 = 2 assert(res2 == exp2, s"Expected $exp2, got $res2") @@ -37,7 +38,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val res3 = Simulate.sim(expr3) + val (res3, _) = Simulate.sim(expr3) val exp3 = 4 assert(res3 == exp3, s"Expected $exp3, got $res3") @@ -49,6 +50,44 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val res4 = Simulate.sim(expr4) + val (res4, _) = Simulate.sim(expr4) val exp4 = 3 assert(res4 == exp4, s"Expected $exp4, got $res4") + + test("simulate mixed arithmetic, reads and when"): + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (128 until 0 by -1).toArray[Result] + + val sc = SimContext().addBuffer(buffer, array) + + val a: Int32 = 32 + val b: Int32 = 64 + val c: Int32 = 4 + + val readExpr1 = ReadBuffer(buffer, a) // 96 + val expr1 = Mul(c, fromExpr(readExpr1)) // 4 * 96 = 384 + + val readExpr2 = ReadBuffer(buffer, b) // 64 + val expr2 = Sum(c, fromExpr(readExpr2)) // 4 + 64 = 68 + + val expr3 = Mod(fromExpr(expr2), 5) // 68 % 5 = 3 + + val cond1 = fromExpr(expr1) <= fromExpr(expr2) // 384 <= 68 false + val cond2 = Equal(fromExpr(expr1), fromExpr(expr2)) // 384 == 68 false + val cond3 = GreaterThanEqual(fromExpr(expr3), fromExpr(expr2)) // 3 >= 68 false + + val expr = WhenExpr( + when = cond1, // false + thenCode = Scope(expr1), + otherConds = List(Scope(cond2), Scope(cond3)), // false false + otherCaseCodes = List(Scope(expr1), Scope(expr2)), // 384, 68 + otherwise = Scope(expr3) // 3 + ) + val (res, newSc) = Simulate.sim(expr, sc) + val exp = 3 + assert(res == exp, s"Expected $exp, got $res") + + // There should be 2 reads in the simulation context + assert(newSc.reads.contains(Read(buffer, 32))) + assert(newSc.reads.contains(Read(buffer, 64))) \ No newline at end of file diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index f4f1b929..12a22ff9 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -27,7 +27,7 @@ object Interpreter: case WriteBuffer(buffer, index, value) => val i = Simulate.sim(index).asInstanceOf[Int] // reads happening here? val res = Simulate.sim(value) // reads happening here? (has its own SimContext) - sc.bufMap(buffer)(i) = res + // sc.bufMap(buffer)(i) = res val writes = List(Write(buffer, i, res)) val reads = ??? // get all the reads from the SimContext? InterpretResult: diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala index c2007d96..7e919211 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -4,13 +4,11 @@ import io.computenode.cyfra.dsl.{*, given} import binding.GBuffer import Result.Result -import scala.collection.mutable.Map as MMap - -case class SimContext(exprMap: MMap[Int, Result] = MMap(), bufMap: MMap[GBuffer[?], Array[Result]] = MMap(), var reads: List[Read] = Nil): - def addBuffer(buffer: GBuffer[?], array: Array[Result]): Unit = bufMap.addOne(buffer -> array) - def addResult(treeid: Int, result: Result): Unit = exprMap.addOne(treeid -> result) - def addRead(buffer: GBuffer[?], index: Int): Unit = reads ::= Read(buffer, index) - def read(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) - def lookup(treeid: Int): Result = exprMap(treeid) +case class SimContext(exprMap: Map[Int, Result] = Map(), bufMap: Map[GBuffer[?], Array[Result]] = Map(), reads: List[Read] = Nil, writes: List[Write] = Nil): + def addBuffer(buffer: GBuffer[?], array: Array[Result]): SimContext = copy(bufMap = bufMap + (buffer -> array)) + def addResult(treeid: Int, result: Result): SimContext = copy(exprMap = exprMap + (treeid -> result)) + def addRead(buffer: GBuffer[?], index: Int): SimContext = copy(reads = Read(buffer, index) :: reads) + def lookupRead(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) + def lookupExpr(treeid: Int): Result = exprMap(treeid) case class Read(buffer: GBuffer[?], index: Int) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 56a27ea0..bb9434e0 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -10,115 +10,193 @@ import collection.mutable.Map as MMap object Simulate: import Result.* - def sim(v: Value): Result = sim(v.tree) // helpful wrapper for Value instead of Expression - def sim(e: Expression[?])(using sc: SimContext = SimContext()): Result = simIterate(buildBlock(e)) + def sim(v: Value): (Result, SimContext) = sim(v.tree) // helpful wrapper for Value instead of Expression + def sim(e: Expression[?], sc: SimContext = SimContext()): (Result, SimContext) = simIterate(buildBlock(e), sc) @annotation.tailrec - def simIterate(blocks: List[Expression[?]])(using sc: SimContext): Result = blocks match - case head :: Nil => simOne(head) + def simIterate(blocks: List[Expression[?]], sc: SimContext): (Result, SimContext) = blocks match + case head :: Nil => simOne(head, sc) case head :: next => - val result = simOne(head) - sc.addResult(head.treeid, result) - simIterate(next) + val (result, sc1) = simOne(head, sc) + val newSc = sc1.addResult(head.treeid, result) + simIterate(next, newSc) case Nil => ??? // should not happen - def simOne(e: Expression[?])(using sc: SimContext): Result = e match - case e: PhantomExpression[?] => simPhantom(e) - case Negate(a) => simValue(a).negate - case e: BinaryOpExpression[?] => simBinOp(e) - case ScalarProd(a, b) => simVector(a).scale(simScalar(b)) - case DotProd(a, b) => simVector(a).dot(simVector(b)) - case e: BitwiseOpExpression[?] => simBitwiseOp(e) - case e: ComparisonOpExpression[?] => simCompareOp(e) - case And(a, b) => simScalar(a) && simScalar(b) - case Or(a, b) => simScalar(a) || simScalar(b) - case Not(a) => simScalar(a).negate - case ExtractScalar(a, i) => simVector(a).apply(simValue(i).asInstanceOf[Int]) - case e: ConvertExpression[?, ?] => simConvert(e) - case e: Const[?] => simConst(e) - case ComposeVec2(a, b) => Vector(simScalar(a), simScalar(b)) - case ComposeVec3(a, b, c) => Vector(simScalar(a), simScalar(b), simScalar(c)) - case ComposeVec4(a, b, c, d) => Vector(simScalar(a), simScalar(b), simScalar(c), simScalar(d)) - case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue)) - case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue)) + def simOne(e: Expression[?], sc: SimContext): (Result, SimContext) = e match + case e: PhantomExpression[?] => simPhantom(e, sc) + case Negate(a) => + val (res, newSc) = simValue(a, sc) + (res.negate, newSc) + case e: BinaryOpExpression[?] => simBinOp(e, sc) + case ScalarProd(a, b) => + val (resA, scA) = simVector(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.scale(resB), scB) + case DotProd(a, b) => + val (resA, scA) = simVector(a, sc) + val (resB, scB) = simVector(b, scA) + (resA.dot(resB), scB) + case e: BitwiseOpExpression[?] => simBitwiseOp(e, sc) + case e: ComparisonOpExpression[?] => simCompareOp(e, sc) + case And(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA && resB, scB) + case Or(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA || resB, scB) + case Not(a) => + val (res, newSc) = simScalar(a, sc) + (res.negate, newSc) + case ExtractScalar(a, i) => + val (resA, scA) = simVector(a, sc) + val (index, scI) = simValue(i, scA) + (resA.apply(index.asInstanceOf[Int]), scI) + case e: ConvertExpression[?, ?] => simConvert(e, sc) + case e: Const[?] => (simConst(e), sc) + case ComposeVec2(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (Vector(resA, resB), scB) + case ComposeVec3(a, b, c) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + val (resC, scC) = simScalar(c, scB) + (Vector(resA, resB, resC), scC) + case ComposeVec4(a, b, c, d) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + val (resC, scC) = simScalar(c, scB) + val (resD, scD) = simScalar(d, scC) + (Vector(resA, resB, resC, resD), scD) + case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue), sc) + case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue), sc) case InvocationId => ??? case Pass(value) => ??? case Dynamic(source) => ??? - case e: WhenExpr[?] => simWhen(e) - case e: ReadBuffer[?] => simReadBuffer(e) - case e: ReadUniform[?] => simReadUniform(e) - case e: GArrayElem[?] => simGArrayElem(e) - case e: FoldSeq[?, ?] => simFoldSeq(e) - case e: ComposeStruct[?] => simComposeStruct(e) - case e: GetField[?, ?] => simGetField(e) + case e: WhenExpr[?] => simWhen(e, sc) + case e: ReadBuffer[?] => simReadBuffer(e, sc) + case e: ReadUniform[?] => simReadUniform(e, sc) + case e: GArrayElem[?] => simGArrayElem(e, sc) + case e: FoldSeq[?, ?] => simFoldSeq(e, sc) + case e: ComposeStruct[?] => simComposeStruct(e, sc) + case e: GetField[?, ?] => simGetField(e, sc) case _ => throw IllegalArgumentException("sim: wrong argument") - private def simPhantom(e: PhantomExpression[?])(using sc: SimContext): Result = e match + private def simPhantom(e: PhantomExpression[?], sc: SimContext): (Result, SimContext) = e match case CurrentElem(tid: Int) => ??? case AggregateElem(tid: Int) => ??? - private def simBinOp(e: BinaryOpExpression[?])(using sc: SimContext): Result = e match - case Sum(a, b) => simValue(a).add(simValue(b)) // scalar or vector - case Diff(a, b) => simValue(a).sub(simValue(b)) // scalar or vector - case Mul(a, b) => simScalar(a).mul(simScalar(b)) - case Div(a, b) => simScalar(a).div(simScalar(b)) - case Mod(a, b) => simScalar(a).mod(simScalar(b)) - - private def simBitwiseOp(e: BitwiseOpExpression[?])(using sc: SimContext): Int = e match - case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) - case BitwiseNot(a) => simScalar(a).bitNeg - case ShiftLeft(a, by) => simScalar(a).shiftLeft(simScalar(by)) - case ShiftRight(a, by) => simScalar(a).shiftRight(simScalar(by)) - - private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using sc: SimContext): Int = e match - case BitwiseAnd(a, b) => simScalar(a).bitAnd(simScalar(b)) - case BitwiseOr(a, b) => simScalar(a).bitOr(simScalar(b)) - case BitwiseXor(a, b) => simScalar(a).bitXor(simScalar(b)) - - private def simCompareOp(e: ComparisonOpExpression[?])(using sc: SimContext): Boolean = e match - case GreaterThan(a, b) => simScalar(a).gt(simScalar(b)) - case LessThan(a, b) => simScalar(a).lt(simScalar(b)) - case GreaterThanEqual(a, b) => simScalar(a).gteq(simScalar(b)) - case LessThanEqual(a, b) => simScalar(a).lteq(simScalar(b)) - case Equal(a, b) => simScalar(a).eql(simScalar(b)) - - private def simConvert(e: ConvertExpression[?, ?])(using sc: SimContext): Float | Int = e match + private def simBinOp(e: BinaryOpExpression[?], sc: SimContext): (Result, SimContext) = e match + case Sum(a, b) => // scalar or vector + val (resA, scA) = simValue(a, sc) + val (resB, scB) = simValue(b, scA) + (resA.add(resB), scB) + case Diff(a, b) => // scalar or vector + val (resA, scA) = simValue(a, sc) + val (resB, scB) = simValue(b, scA) + (resA.sub(resB), scB) + case Mul(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.mul(resB), scB) + case Div(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.div(resB), scB) + case Mod(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.mod(resB), scB) + + private def simBitwiseOp(e: BitwiseOpExpression[?], sc: SimContext): (Int, SimContext) = e match + case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e, sc) + case BitwiseNot(a) => + val (res, newSc) = simScalar(a, sc) + (res.bitNeg, newSc) + case ShiftLeft(a, by) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(by, scA) + (resA.shiftLeft(resB), scB) + case ShiftRight(a, by) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(by, scA) + (resA.shiftRight(resB), scB) + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?], sc: SimContext): (Int, SimContext) = e match + case BitwiseAnd(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.bitAnd(resB), scB) + case BitwiseOr(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.bitOr(resB), scB) + case BitwiseXor(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.bitXor(resB), scB) + + private def simCompareOp(e: ComparisonOpExpression[?], sc: SimContext): (Boolean, SimContext) = e match + case GreaterThan(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.gt(resB), scB) + case LessThan(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.lt(resB), scB) + case GreaterThanEqual(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.gteq(resB), scB) + case LessThanEqual(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.lteq(resB), scB) + case Equal(a, b) => + val (resA, scA) = simScalar(a, sc) + val (resB, scB) = simScalar(b, scA) + (resA.eql(resB), scB) + + private def simConvert(e: ConvertExpression[?, ?], sc: SimContext): (Float | Int, SimContext) = e match case ToFloat32(a) => - sc.lookup(a.treeid) match - case f: Float => f + sc.lookupExpr(a.treeid) match + case f: Float => (f, sc) case _ => throw IllegalArgumentException("ToFloat32: wrong argument type") case ToInt32(a) => - sc.lookup(a.treeid) match - case n: Int => n + sc.lookupExpr(a.treeid) match + case n: Int => (n, sc) case _ => throw IllegalArgumentException("ToInt32: wrong argument type") case ToUInt32(a) => - sc.lookup(a.treeid) match - case n: Int => n + sc.lookupExpr(a.treeid) match + case n: Int => (n, sc) case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") - private def simConst(e: Const[?]): ScalarRes = e match + private def simConst(e: Const[?]): ScalarRes = e match // no context needed case ConstFloat32(value) => value case ConstInt32(value) => value case ConstUInt32(value) => value case ConstGB(value) => value - private def simValue(v: Value)(using sc: SimContext): Result = v match - case v: Scalar => simScalar(v) - case v: Vec[?] => simVector(v) + private def simValue(v: Value, sc: SimContext): (Result, SimContext) = v match + case v: Scalar => simScalar(v, sc) + case v: Vec[?] => simVector(v, sc) - private def simScalar(v: Scalar)(using sc: SimContext): ScalarRes = v match - case v: FloatType => sc.lookup(v.tree.treeid).asInstanceOf[Float] - case v: IntType => sc.lookup(v.tree.treeid).asInstanceOf[Int] - case v: UIntType => sc.lookup(v.tree.treeid).asInstanceOf[Int] - case GBoolean(source) => sc.lookup(source.treeid).asInstanceOf[Boolean] + private def simScalar(v: Scalar, sc: SimContext): (ScalarRes, SimContext) = v match + case v: FloatType => (sc.lookupExpr(v.tree.treeid).asInstanceOf[Float], sc) + case v: IntType => (sc.lookupExpr(v.tree.treeid).asInstanceOf[Int], sc) + case v: UIntType => (sc.lookupExpr(v.tree.treeid).asInstanceOf[Int], sc) + case GBoolean(source) => (sc.lookupExpr(source.treeid).asInstanceOf[Boolean], sc) - private def simVector(v: Vec[?])(using sc: SimContext) = v match - case Vec2(tree) => sc.lookup(tree.treeid).asInstanceOf[Vector[ScalarRes]] - case Vec3(tree) => sc.lookup(tree.treeid).asInstanceOf[Vector[ScalarRes]] - case Vec4(tree) => sc.lookup(tree.treeid).asInstanceOf[Vector[ScalarRes]] + private def simVector(v: Vec[?], sc: SimContext): (Vector[ScalarRes], SimContext) = v match + case Vec2(tree) => (sc.lookupExpr(tree.treeid).asInstanceOf[Vector[ScalarRes]], sc) + case Vec3(tree) => (sc.lookupExpr(tree.treeid).asInstanceOf[Vector[ScalarRes]], sc) + case Vec4(tree) => (sc.lookupExpr(tree.treeid).asInstanceOf[Vector[ScalarRes]], sc) - private def simExtFunc(fn: FunctionName, args: List[Result])(using sc: SimContext): Result = ??? - private def simFunc(fn: FnIdentifier, body: Result, args: List[Result])(using sc: SimContext): Result = ??? + private def simExtFunc(fn: FunctionName, args: List[Result], sc: SimContext): (Result, SimContext) = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result], sc: SimContext): (Result, SimContext) = ??? @annotation.tailrec private def whenHelper( @@ -127,41 +205,44 @@ object Simulate: otherConds: List[Scope[GBoolean]], otherCaseCodes: List[Scope[?]], otherwise: Scope[?], - )(using sc: SimContext): Result = - if sim(when).asInstanceOf[Boolean] then sim(thenCode.expr) + sc: SimContext + ): (Result, SimContext) = + val (boolRes, newSc) = sim(when, sc) + if boolRes.asInstanceOf[Boolean] then sim(thenCode.expr, newSc) else otherConds.headOption match - case None => sim(otherwise.expr) - case Some(cond) => - whenHelper( - when = cond.expr, - thenCode = otherCaseCodes.head, - otherConds = otherConds.tail, - otherCaseCodes = otherCaseCodes.tail, - otherwise = otherwise, - ) - - private def simWhen(e: WhenExpr[?])(using sc: SimContext): Result = e match + case None => sim(otherwise.expr, newSc) + case Some(cond) => whenHelper( + when = cond.expr, + thenCode = otherCaseCodes.head, + otherConds = otherConds.tail, + otherCaseCodes = otherCaseCodes.tail, + otherwise = otherwise, + sc = newSc + ) + + private def simWhen(e: WhenExpr[?], sc: SimContext): (Result, SimContext) = e match case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => - whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise) + whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, sc) - private def simReadBuffer(buf: ReadBuffer[?])(using sc: SimContext): Result = buf match + private def simReadBuffer(buf: ReadBuffer[?], sc: SimContext): (Result, SimContext) = buf match case ReadBuffer(buffer, index) => - val i = sim(index).asInstanceOf[Int] - sc.addRead(buffer, i) - sc.read(buffer, i) + val (res, sc1) = sim(index.tree, sc) + val i = res.asInstanceOf[Int] + val newSc = sc1.addRead(buffer, i) + (newSc.lookupRead(buffer, i), newSc) - private def simReadUniform(uni: ReadUniform[?])(using sc: SimContext): Result = uni match + private def simReadUniform(uni: ReadUniform[?], sc: SimContext): (Result, SimContext) = uni match case ReadUniform(uniform) => ??? - private def simGArrayElem(gElem: GArrayElem[?])(using sc: SimContext): Result = gElem match + private def simGArrayElem(gElem: GArrayElem[?], sc: SimContext): (Result, SimContext) = gElem match case GArrayElem(index, i) => ??? - private def simFoldSeq(seq: FoldSeq[?, ?])(using sc: SimContext): Result = seq match + private def simFoldSeq(seq: FoldSeq[?, ?], sc: SimContext): (Result, SimContext) = seq match case FoldSeq(zero, fn, seq) => ??? - private def simComposeStruct(cs: ComposeStruct[?])(using sc: SimContext): Result = cs match + private def simComposeStruct(cs: ComposeStruct[?], sc: SimContext): (Result, SimContext) = cs match case ComposeStruct(fields, resultSchema) => ??? - private def simGetField(gf: GetField[?, ?])(using sc: SimContext): Result = gf match + private def simGetField(gf: GetField[?, ?], sc: SimContext): (Result, SimContext) = gf match case GetField(struct, fieldIndex) => ??? From cc2520dfc7b41bfaf52322bf50b6dc3de6b37950 Mon Sep 17 00:00:00 2001 From: spamegg Date: Thu, 17 Jul 2025 18:52:58 +0300 Subject: [PATCH 08/13] fix formatting --- .../e2e/interpreter/SimulateWhenTests.scala | 4 +- .../cyfra/interpreter/SimContext.scala | 7 +- .../cyfra/interpreter/Simulate.scala | 67 ++++++++++--------- 3 files changed, 42 insertions(+), 36 deletions(-) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala index 841192b8..29f46957 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -82,7 +82,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: thenCode = Scope(expr1), otherConds = List(Scope(cond2), Scope(cond3)), // false false otherCaseCodes = List(Scope(expr1), Scope(expr2)), // 384, 68 - otherwise = Scope(expr3) // 3 + otherwise = Scope(expr3), // 3 ) val (res, newSc) = Simulate.sim(expr, sc) val exp = 3 @@ -90,4 +90,4 @@ class SimulateWhenE2eTest extends munit.FunSuite: // There should be 2 reads in the simulation context assert(newSc.reads.contains(Read(buffer, 32))) - assert(newSc.reads.contains(Read(buffer, 64))) \ No newline at end of file + assert(newSc.reads.contains(Read(buffer, 64))) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala index 7e919211..fe237f3a 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -4,7 +4,12 @@ import io.computenode.cyfra.dsl.{*, given} import binding.GBuffer import Result.Result -case class SimContext(exprMap: Map[Int, Result] = Map(), bufMap: Map[GBuffer[?], Array[Result]] = Map(), reads: List[Read] = Nil, writes: List[Write] = Nil): +case class SimContext( + exprMap: Map[Int, Result] = Map(), + bufMap: Map[GBuffer[?], Array[Result]] = Map(), + reads: List[Read] = Nil, + writes: List[Write] = Nil, +): def addBuffer(buffer: GBuffer[?], array: Array[Result]): SimContext = copy(bufMap = bufMap + (buffer -> array)) def addResult(treeid: Int, result: Result): SimContext = copy(exprMap = exprMap + (treeid -> result)) def addRead(buffer: GBuffer[?], index: Int): SimContext = copy(reads = Read(buffer, index) :: reads) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index bb9434e0..7979eb5b 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -23,16 +23,16 @@ object Simulate: case Nil => ??? // should not happen def simOne(e: Expression[?], sc: SimContext): (Result, SimContext) = e match - case e: PhantomExpression[?] => simPhantom(e, sc) - case Negate(a) => + case e: PhantomExpression[?] => simPhantom(e, sc) + case Negate(a) => val (res, newSc) = simValue(a, sc) (res.negate, newSc) - case e: BinaryOpExpression[?] => simBinOp(e, sc) - case ScalarProd(a, b) => + case e: BinaryOpExpression[?] => simBinOp(e, sc) + case ScalarProd(a, b) => val (resA, scA) = simVector(a, sc) val (resB, scB) = simScalar(b, scA) (resA.scale(resB), scB) - case DotProd(a, b) => + case DotProd(a, b) => val (resA, scA) = simVector(a, sc) val (resB, scB) = simVector(b, scA) (resA.dot(resB), scB) @@ -42,29 +42,29 @@ object Simulate: val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA && resB, scB) - case Or(a, b) => + case Or(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA || resB, scB) - case Not(a) => + case Not(a) => val (res, newSc) = simScalar(a, sc) (res.negate, newSc) - case ExtractScalar(a, i) => + case ExtractScalar(a, i) => val (resA, scA) = simVector(a, sc) val (index, scI) = simValue(i, scA) (resA.apply(index.asInstanceOf[Int]), scI) - case e: ConvertExpression[?, ?] => simConvert(e, sc) - case e: Const[?] => (simConst(e), sc) - case ComposeVec2(a, b) => + case e: ConvertExpression[?, ?] => simConvert(e, sc) + case e: Const[?] => (simConst(e), sc) + case ComposeVec2(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (Vector(resA, resB), scB) - case ComposeVec3(a, b, c) => + case ComposeVec3(a, b, c) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) val (resC, scC) = simScalar(c, scB) (Vector(resA, resB, resC), scC) - case ComposeVec4(a, b, c, d) => + case ComposeVec4(a, b, c, d) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) val (resC, scC) = simScalar(c, scB) @@ -89,7 +89,7 @@ object Simulate: case AggregateElem(tid: Int) => ??? private def simBinOp(e: BinaryOpExpression[?], sc: SimContext): (Result, SimContext) = e match - case Sum(a, b) => // scalar or vector + case Sum(a, b) => // scalar or vector val (resA, scA) = simValue(a, sc) val (resB, scB) = simValue(b, scA) (resA.add(resB), scB) @@ -97,15 +97,15 @@ object Simulate: val (resA, scA) = simValue(a, sc) val (resB, scB) = simValue(b, scA) (resA.sub(resB), scB) - case Mul(a, b) => + case Mul(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.mul(resB), scB) - case Div(a, b) => + case Div(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.div(resB), scB) - case Mod(a, b) => + case Mod(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.mod(resB), scB) @@ -115,11 +115,11 @@ object Simulate: case BitwiseNot(a) => val (res, newSc) = simScalar(a, sc) (res.bitNeg, newSc) - case ShiftLeft(a, by) => + case ShiftLeft(a, by) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(by, scA) (resA.shiftLeft(resB), scB) - case ShiftRight(a, by) => + case ShiftRight(a, by) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(by, scA) (resA.shiftRight(resB), scB) @@ -129,7 +129,7 @@ object Simulate: val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.bitAnd(resB), scB) - case BitwiseOr(a, b) => + case BitwiseOr(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.bitOr(resB), scB) @@ -139,11 +139,11 @@ object Simulate: (resA.bitXor(resB), scB) private def simCompareOp(e: ComparisonOpExpression[?], sc: SimContext): (Boolean, SimContext) = e match - case GreaterThan(a, b) => + case GreaterThan(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.gt(resB), scB) - case LessThan(a, b) => + case LessThan(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.lt(resB), scB) @@ -151,11 +151,11 @@ object Simulate: val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.gteq(resB), scB) - case LessThanEqual(a, b) => + case LessThanEqual(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.lteq(resB), scB) - case Equal(a, b) => + case Equal(a, b) => val (resA, scA) = simScalar(a, sc) val (resB, scB) = simScalar(b, scA) (resA.eql(resB), scB) @@ -205,21 +205,22 @@ object Simulate: otherConds: List[Scope[GBoolean]], otherCaseCodes: List[Scope[?]], otherwise: Scope[?], - sc: SimContext + sc: SimContext, ): (Result, SimContext) = val (boolRes, newSc) = sim(when, sc) if boolRes.asInstanceOf[Boolean] then sim(thenCode.expr, newSc) else otherConds.headOption match case None => sim(otherwise.expr, newSc) - case Some(cond) => whenHelper( - when = cond.expr, - thenCode = otherCaseCodes.head, - otherConds = otherConds.tail, - otherCaseCodes = otherCaseCodes.tail, - otherwise = otherwise, - sc = newSc - ) + case Some(cond) => + whenHelper( + when = cond.expr, + thenCode = otherCaseCodes.head, + otherConds = otherConds.tail, + otherCaseCodes = otherCaseCodes.tail, + otherwise = otherwise, + sc = newSc, + ) private def simWhen(e: WhenExpr[?], sc: SimContext): (Result, SimContext) = e match case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => From 20bccb23603b9a31f25db78a8103129cd26f490c Mon Sep 17 00:00:00 2001 From: spamegg Date: Sat, 19 Jul 2025 18:12:28 +0300 Subject: [PATCH 09/13] some simplifications --- .../e2e/interpreter/InterpreterTests.scala | 7 +- .../cyfra/e2e/interpreter/SimulateTests.scala | 15 +- .../e2e/interpreter/SimulateWhenTests.scala | 8 +- .../cyfra/interpreter/Interpreter.scala | 54 ++-- .../cyfra/interpreter/SimContext.scala | 16 +- .../cyfra/interpreter/Simulate.scala | 251 ++++++------------ 6 files changed, 138 insertions(+), 213 deletions(-) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala index 7b5f57a8..27483c16 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala @@ -2,10 +2,15 @@ package io.computenode.cyfra.e2e.interpreter import io.computenode.cyfra.interpreter.*, Result.* import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* import Value.FromExpr.fromExpr, control.Scope class InterpreterE2eTest extends munit.FunSuite: - test("stub"): + test("interpret should not stack overflow"): + val pure = Pure(0) + var gio = FlatMap(pure, pure) + for _ <- 0 until 1000000 do gio = FlatMap(pure, gio) + val result = Interpreter.interpret(gio, 0) val res = 0 val exp = 0 assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index 531adc1e..27100174 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -18,29 +18,30 @@ class SimulateE2eTest extends munit.FunSuite: val e3 = Mul(f, fromExpr(e2)) val e4 = Div(fromExpr(e3), d) val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4) - val (result, _) = Simulate.sim(expr) + val (result, _) = Simulate.sim(expr, SimContext()) val expected = 2 assert(result == expected, s"Expected $expected, got $result") test("simulate vec4, scalar, dot, extract scalar"): val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) - val (res1, _) = Simulate.sim(v1) + val (res1, _) = Simulate.sim(v1, SimContext()) val exp1 = Vector(1f, 2f, 3f, 4f) assert(res1 == exp1, s"Expected $exp1, got $res1") val i: Int32 = 2 - val (res, _) = Simulate.sim(ExtractScalar(fromExpr(v1), i)) + val expr = ExtractScalar(fromExpr(v1), i) + val (res, _) = Simulate.sim(expr, SimContext()) val exp = 3f assert(res == exp, s"Expected $exp, got $res") val v2 = ScalarProd(fromExpr(v1), -1f) - val (res2, _) = Simulate.sim(v2) + val (res2, _) = Simulate.sim(v2, SimContext()) val exp2 = Vector(-1f, -2f, -3f, -4f) assert(res2 == exp2, s"Expected $exp2, got $res2") val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f) val dot = DotProd(fromExpr(v1), fromExpr(v3)) - val (res3a, _) = Simulate.sim(dot) + val (res3a, _) = Simulate.sim(dot, SimContext()) val res3 = res3a.asInstanceOf[Float] val exp3 = 0f assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3") @@ -55,7 +56,7 @@ class SimulateE2eTest extends munit.FunSuite: val or = BitwiseOr(fromExpr(left), fromExpr(right)) val xor = BitwiseXor(fromExpr(and), fromExpr(or)) - val (res, _) = Simulate.sim(xor) + val (res, _) = Simulate.sim(xor, SimContext()) val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) assert(res == exp, s"Expected $exp, got $res") @@ -63,7 +64,7 @@ class SimulateE2eTest extends munit.FunSuite: val a: Int32 = 1 var sum = Sum(a, a) // 2 for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) - val (res, _) = Simulate.sim(sum) + val (res, _) = Simulate.sim(sum, SimContext()) val exp = 1000002 assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala index 29f46957..bf3ee530 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -14,7 +14,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val (res1, _) = Simulate.sim(expr1) + val (res1, _) = Simulate.sim(expr1, SimContext()) val exp1 = 1 assert(res1 == exp1, s"Expected $exp1, got $res1") @@ -26,7 +26,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val (res2, _) = Simulate.sim(expr2) + val (res2, _) = Simulate.sim(expr2, SimContext()) val exp2 = 2 assert(res2 == exp2, s"Expected $exp2, got $res2") @@ -38,7 +38,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val (res3, _) = Simulate.sim(expr3) + val (res3, _) = Simulate.sim(expr3, SimContext()) val exp3 = 4 assert(res3 == exp3, s"Expected $exp3, got $res3") @@ -50,7 +50,7 @@ class SimulateWhenE2eTest extends munit.FunSuite: otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), otherwise = Scope(ConstInt32(3)), ) - val (res4, _) = Simulate.sim(expr4) + val (res4, _) = Simulate.sim(expr4, SimContext()) val exp4 = 3 assert(res4 == exp4, s"Expected $exp4, got $res4") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 12a22ff9..354f60bb 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -2,10 +2,9 @@ package io.computenode.cyfra.interpreter import io.computenode.cyfra.dsl.{*, given} import binding.*, Value.*, gio.GIO, GIO.* +import Result.Result import izumi.reflect.Tag -case class Write(buffer: GBuffer[?], index: Int, value: Any) - case class InvocResult( invocId: Int, instructions: List[Expression[?]] = Nil, @@ -18,40 +17,39 @@ case class InterpretResult(invocs: List[InvocResult] = Nil): def add(that: InterpretResult) = InterpretResult(that.invocs ::: invocs) object Interpreter: - def interpretPure(gio: Pure[?], invocId: Int)(using sc: SimContext): InterpretResult = gio match - case Pure(value) => - InterpretResult: - List(InvocResult(invocId, ???, List(value), ???, sc.reads)) + def interpretPure(gio: Pure[?], invocId: Int, sc: SimContext): InvocResult = gio match + case Pure(value) => InvocResult(invocId, ???, List(value), ???, ???) - def interpretWriteBuffer(gio: WriteBuffer[?], invocId: Int)(using sc: SimContext): InterpretResult = gio match + def interpretWriteBuffer(gio: WriteBuffer[?], invocId: Int, sc: SimContext): InvocResult = gio match case WriteBuffer(buffer, index, value) => - val i = Simulate.sim(index).asInstanceOf[Int] // reads happening here? - val res = Simulate.sim(value) // reads happening here? (has its own SimContext) - // sc.bufMap(buffer)(i) = res - val writes = List(Write(buffer, i, res)) - val reads = ??? // get all the reads from the SimContext? - InterpretResult: - List(InvocResult(invocId, ???, List(value), writes, reads)) - - def interpretWriteUniform(gio: WriteUniform[?], invocId: Int)(using sc: SimContext): InterpretResult = gio match + val (n, _) = Simulate.sim(index) + val i = n.asInstanceOf[Int] + val (res, _) = Simulate.sim(value) + val newSc = sc.addWrite(buffer, i, res) + InvocResult(invocId, ???, List(value), ???, ???) + + def interpretWriteUniform(gio: WriteUniform[?], invocId: Int, sc: SimContext): InvocResult = gio match case WriteUniform(uniform, value) => ??? - def interpretOne(gio: GIO[?])(using SimContext): InterpretResult = - val invocId = Simulate.sim(invocationId).asInstanceOf[Int] + def interpretOne(gio: GIO[?], invocId: Int, sc: SimContext): InvocResult = + // val invocId = Simulate.sim(invocationId, sc, Map()).asInstanceOf[Int] gio match - case p: Pure[?] => interpretPure(p, invocId) - case wb: WriteBuffer[?] => interpretWriteBuffer(wb, invocId) - case wu: WriteUniform[?] => interpretWriteUniform(wu, invocId) + case p: Pure[?] => interpretPure(p, invocId, sc) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, invocId, sc) + case wu: WriteUniform[?] => interpretWriteUniform(wu, invocId, sc) case _ => throw IllegalArgumentException("interpretOne: invalid GIO") @annotation.tailrec - def interpretMany(gios: List[GIO[?]], res: InterpretResult)(using SimContext): InterpretResult = gios match - case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, res) + def interpretMany(gios: List[GIO[?]], invocId: Int, res: InvocResult, sc: SimContext): InvocResult = gios match + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, invocId, res, sc) case Repeat(n, f) :: tail => - val int = Simulate.sim(n).asInstanceOf[Int] + val (i, _) = Simulate.sim(n) + val int = i.asInstanceOf[Int] val newGios = (0 until int).map(i => f(i)).toList - interpretMany(newGios ::: tail, res) - case head :: tail => interpretMany(tail, res.add(interpretOne(head))) - case Nil => res + interpretMany(newGios ::: tail, invocId, res, sc) + case head :: tail => + val newRes = interpretOne(head, invocId, sc) + interpretMany(tail, invocId, ???, sc) + case Nil => res - def interpret(gio: GIO[?])(using SimContext): InterpretResult = interpretMany(List(gio), InterpretResult()) + def interpret(gio: GIO[?], invocId: Int): InvocResult = interpretMany(List(gio), invocId, InvocResult(invocId), SimContext()) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala index fe237f3a..c9b98201 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -4,16 +4,14 @@ import io.computenode.cyfra.dsl.{*, given} import binding.GBuffer import Result.Result -case class SimContext( - exprMap: Map[Int, Result] = Map(), - bufMap: Map[GBuffer[?], Array[Result]] = Map(), - reads: List[Read] = Nil, - writes: List[Write] = Nil, -): +case class SimContext(bufMap: Map[GBuffer[?], Array[Result]] = Map(), reads: List[Read] = Nil, writes: List[Write] = Nil): def addBuffer(buffer: GBuffer[?], array: Array[Result]): SimContext = copy(bufMap = bufMap + (buffer -> array)) - def addResult(treeid: Int, result: Result): SimContext = copy(exprMap = exprMap + (treeid -> result)) def addRead(buffer: GBuffer[?], index: Int): SimContext = copy(reads = Read(buffer, index) :: reads) - def lookupRead(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) - def lookupExpr(treeid: Int): Result = exprMap(treeid) + def addWrite(buffer: GBuffer[?], index: Int, value: Result): SimContext = + val newArray = bufMap(buffer).updated(index, value) + val newWrites = Write(buffer, index, value) :: writes + copy(bufMap = bufMap.updated(buffer, newArray), writes = newWrites) + def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) case class Read(buffer: GBuffer[?], index: Int) +case class Write(buffer: GBuffer[?], index: Int, value: Result) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 7979eb5b..9a7c2cc8 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -10,190 +10,112 @@ import collection.mutable.Map as MMap object Simulate: import Result.* - def sim(v: Value): (Result, SimContext) = sim(v.tree) // helpful wrapper for Value instead of Expression - def sim(e: Expression[?], sc: SimContext = SimContext()): (Result, SimContext) = simIterate(buildBlock(e), sc) + def sim(v: Value): (Result, SimContext) = sim(v.tree) + def sim(e: Expression[?], sc: SimContext = SimContext()): (Result, SimContext) = simIterate(buildBlock(e), sc)(using Map()) @annotation.tailrec - def simIterate(blocks: List[Expression[?]], sc: SimContext): (Result, SimContext) = blocks match + def simIterate(blocks: List[Expression[?]], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = blocks match case head :: Nil => simOne(head, sc) case head :: next => - val (result, sc1) = simOne(head, sc) - val newSc = sc1.addResult(head.treeid, result) - simIterate(next, newSc) + val (result, newSc) = simOne(head, sc) // context updated if there are reads/writes + val newExprMap = exprMap + (head.treeid -> result) // update map with new result + simIterate(next, newSc)(using newExprMap) case Nil => ??? // should not happen - def simOne(e: Expression[?], sc: SimContext): (Result, SimContext) = e match - case e: PhantomExpression[?] => simPhantom(e, sc) - case Negate(a) => - val (res, newSc) = simValue(a, sc) - (res.negate, newSc) - case e: BinaryOpExpression[?] => simBinOp(e, sc) - case ScalarProd(a, b) => - val (resA, scA) = simVector(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.scale(resB), scB) - case DotProd(a, b) => - val (resA, scA) = simVector(a, sc) - val (resB, scB) = simVector(b, scA) - (resA.dot(resB), scB) - case e: BitwiseOpExpression[?] => simBitwiseOp(e, sc) - case e: ComparisonOpExpression[?] => simCompareOp(e, sc) - case And(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA && resB, scB) - case Or(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA || resB, scB) - case Not(a) => - val (res, newSc) = simScalar(a, sc) - (res.negate, newSc) - case ExtractScalar(a, i) => - val (resA, scA) = simVector(a, sc) - val (index, scI) = simValue(i, scA) - (resA.apply(index.asInstanceOf[Int]), scI) - case e: ConvertExpression[?, ?] => simConvert(e, sc) - case e: Const[?] => (simConst(e), sc) - case ComposeVec2(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (Vector(resA, resB), scB) - case ComposeVec3(a, b, c) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - val (resC, scC) = simScalar(c, scB) - (Vector(resA, resB, resC), scC) - case ComposeVec4(a, b, c, d) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - val (resC, scC) = simScalar(c, scB) - val (resD, scD) = simScalar(d, scC) - (Vector(resA, resB, resC, resD), scD) + def simOne(e: Expression[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = e match + case e: PhantomExpression[?] => (simPhantom(e), sc) + case Negate(a) => (simValue(a).negate, sc) + case e: BinaryOpExpression[?] => (simBinOp(e), sc) + case ScalarProd(a, b) => (simVector(a).scale(simScalar(b)), sc) + case DotProd(a, b) => (simVector(a).dot(simVector(b)), sc) + case e: BitwiseOpExpression[?] => (simBitwiseOp(e), sc) + case e: ComparisonOpExpression[?] => (simCompareOp(e), sc) + case And(a, b) => (simScalar(a) && simScalar(b), sc) + case Or(a, b) => (simScalar(a) || simScalar(b), sc) + case Not(a) => (simScalar(a).negate, sc) + case ExtractScalar(a, i) => (simVector(a).apply(simValue(i).asInstanceOf[Int]), sc) + case e: ConvertExpression[?, ?] => (simConvert(e), sc) + case e: Const[?] => (simConst(e), sc) + case ComposeVec2(a, b) => (Vector(simScalar(a), simScalar(b)), sc) + case ComposeVec3(a, b, c) => (Vector(simScalar(a), simScalar(b), simScalar(c)), sc) + case ComposeVec4(a, b, c, d) => (Vector(simScalar(a), simScalar(b), simScalar(c), simScalar(d)), sc) case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue), sc) case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue), sc) case InvocationId => ??? case Pass(value) => ??? case Dynamic(source) => ??? - case e: WhenExpr[?] => simWhen(e, sc) - case e: ReadBuffer[?] => simReadBuffer(e, sc) - case e: ReadUniform[?] => simReadUniform(e, sc) - case e: GArrayElem[?] => simGArrayElem(e, sc) - case e: FoldSeq[?, ?] => simFoldSeq(e, sc) - case e: ComposeStruct[?] => simComposeStruct(e, sc) - case e: GetField[?, ?] => simGetField(e, sc) + case e: WhenExpr[?] => simWhen(e, sc) // returns new SimContext + case e: ReadBuffer[?] => simReadBuffer(e, sc) // returns new SimContext + case e: ReadUniform[?] => simReadUniform(e) + case e: GArrayElem[?] => simGArrayElem(e) + case e: FoldSeq[?, ?] => simFoldSeq(e) + case e: ComposeStruct[?] => simComposeStruct(e) + case e: GetField[?, ?] => simGetField(e) case _ => throw IllegalArgumentException("sim: wrong argument") - private def simPhantom(e: PhantomExpression[?], sc: SimContext): (Result, SimContext) = e match + private def simPhantom(e: PhantomExpression[?]): Result = e match case CurrentElem(tid: Int) => ??? case AggregateElem(tid: Int) => ??? - private def simBinOp(e: BinaryOpExpression[?], sc: SimContext): (Result, SimContext) = e match - case Sum(a, b) => // scalar or vector - val (resA, scA) = simValue(a, sc) - val (resB, scB) = simValue(b, scA) - (resA.add(resB), scB) - case Diff(a, b) => // scalar or vector - val (resA, scA) = simValue(a, sc) - val (resB, scB) = simValue(b, scA) - (resA.sub(resB), scB) - case Mul(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.mul(resB), scB) - case Div(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.div(resB), scB) - case Mod(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.mod(resB), scB) - - private def simBitwiseOp(e: BitwiseOpExpression[?], sc: SimContext): (Int, SimContext) = e match - case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e, sc) - case BitwiseNot(a) => - val (res, newSc) = simScalar(a, sc) - (res.bitNeg, newSc) - case ShiftLeft(a, by) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(by, scA) - (resA.shiftLeft(resB), scB) - case ShiftRight(a, by) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(by, scA) - (resA.shiftRight(resB), scB) - - private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?], sc: SimContext): (Int, SimContext) = e match - case BitwiseAnd(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.bitAnd(resB), scB) - case BitwiseOr(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.bitOr(resB), scB) - case BitwiseXor(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.bitXor(resB), scB) - - private def simCompareOp(e: ComparisonOpExpression[?], sc: SimContext): (Boolean, SimContext) = e match - case GreaterThan(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.gt(resB), scB) - case LessThan(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.lt(resB), scB) - case GreaterThanEqual(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.gteq(resB), scB) - case LessThanEqual(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.lteq(resB), scB) - case Equal(a, b) => - val (resA, scA) = simScalar(a, sc) - val (resB, scB) = simScalar(b, scA) - (resA.eql(resB), scB) - - private def simConvert(e: ConvertExpression[?, ?], sc: SimContext): (Float | Int, SimContext) = e match + private def simBinOp(e: BinaryOpExpression[?])(using exprMap: Map[Int, Result]): Result = e match + case Sum(a, b) => simValue(a).add(simValue(b)) // scalar or vector + case Diff(a, b) => simValue(a).sub(simValue(b)) // scalar or vector + case Mul(a, b) => simScalar(a).mul(simScalar(b)) + case Div(a, b) => simScalar(a).div(simScalar(b)) + case Mod(a, b) => simScalar(a).mod(simScalar(b)) + + private def simBitwiseOp(e: BitwiseOpExpression[?])(using exprMap: Map[Int, Result]): Int = e match + case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) + case BitwiseNot(a) => simScalar(a).bitNeg + case ShiftLeft(a, by) => simScalar(a).shiftLeft(simScalar(by)) + case ShiftRight(a, by) => simScalar(a).shiftRight(simScalar(by)) + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using exprMap: Map[Int, Result]): Int = e match + case BitwiseAnd(a, b) => simScalar(a).bitAnd(simScalar(b)) + case BitwiseOr(a, b) => simScalar(a).bitOr(simScalar(b)) + case BitwiseXor(a, b) => simScalar(a).bitXor(simScalar(b)) + + private def simCompareOp(e: ComparisonOpExpression[?])(using exprMap: Map[Int, Result]): Boolean = e match + case GreaterThan(a, b) => simScalar(a).gt(simScalar(b)) + case LessThan(a, b) => simScalar(a).lt(simScalar(b)) + case GreaterThanEqual(a, b) => simScalar(a).gteq(simScalar(b)) + case LessThanEqual(a, b) => simScalar(a).lteq(simScalar(b)) + case Equal(a, b) => simScalar(a).eql(simScalar(b)) + + private def simConvert(e: ConvertExpression[?, ?])(using exprMap: Map[Int, Result]): Float | Int = e match case ToFloat32(a) => - sc.lookupExpr(a.treeid) match - case f: Float => (f, sc) + exprMap(a.treeid) match + case f: Float => f case _ => throw IllegalArgumentException("ToFloat32: wrong argument type") case ToInt32(a) => - sc.lookupExpr(a.treeid) match - case n: Int => (n, sc) + exprMap(a.treeid) match + case n: Int => n case _ => throw IllegalArgumentException("ToInt32: wrong argument type") case ToUInt32(a) => - sc.lookupExpr(a.treeid) match - case n: Int => (n, sc) + exprMap(a.treeid) match + case n: Int => n case _ => throw IllegalArgumentException("ToUInt32: wrong argument type") - private def simConst(e: Const[?]): ScalarRes = e match // no context needed + private def simConst(e: Const[?]): ScalarRes = e match case ConstFloat32(value) => value case ConstInt32(value) => value case ConstUInt32(value) => value case ConstGB(value) => value - private def simValue(v: Value, sc: SimContext): (Result, SimContext) = v match - case v: Scalar => simScalar(v, sc) - case v: Vec[?] => simVector(v, sc) + private def simValue(v: Value)(using exprMap: Map[Int, Result]): Result = v match + case v: Scalar => simScalar(v) + case v: Vec[?] => simVector(v) - private def simScalar(v: Scalar, sc: SimContext): (ScalarRes, SimContext) = v match - case v: FloatType => (sc.lookupExpr(v.tree.treeid).asInstanceOf[Float], sc) - case v: IntType => (sc.lookupExpr(v.tree.treeid).asInstanceOf[Int], sc) - case v: UIntType => (sc.lookupExpr(v.tree.treeid).asInstanceOf[Int], sc) - case GBoolean(source) => (sc.lookupExpr(source.treeid).asInstanceOf[Boolean], sc) + private def simScalar(v: Scalar)(using exprMap: Map[Int, Result]): ScalarRes = v match + case v: FloatType => exprMap(v.tree.treeid).asInstanceOf[Float] + case v: IntType => exprMap(v.tree.treeid).asInstanceOf[Int] + case v: UIntType => exprMap(v.tree.treeid).asInstanceOf[Int] + case GBoolean(source) => exprMap(source.treeid).asInstanceOf[Boolean] - private def simVector(v: Vec[?], sc: SimContext): (Vector[ScalarRes], SimContext) = v match - case Vec2(tree) => (sc.lookupExpr(tree.treeid).asInstanceOf[Vector[ScalarRes]], sc) - case Vec3(tree) => (sc.lookupExpr(tree.treeid).asInstanceOf[Vector[ScalarRes]], sc) - case Vec4(tree) => (sc.lookupExpr(tree.treeid).asInstanceOf[Vector[ScalarRes]], sc) + private def simVector(v: Vec[?])(using exprMap: Map[Int, Result]): Vector[ScalarRes] = v match + case Vec2(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec3(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] + case Vec4(tree) => exprMap(tree.treeid).asInstanceOf[Vector[ScalarRes]] private def simExtFunc(fn: FunctionName, args: List[Result], sc: SimContext): (Result, SimContext) = ??? private def simFunc(fn: FnIdentifier, body: Result, args: List[Result], sc: SimContext): (Result, SimContext) = ??? @@ -207,6 +129,8 @@ object Simulate: otherwise: Scope[?], sc: SimContext, ): (Result, SimContext) = + // scopes are not included in the "main" exprMap, they have to be simulated from scratch. + // there could be reads/writes happening in scopes, SimContext has to be updated. val (boolRes, newSc) = sim(when, sc) if boolRes.asInstanceOf[Boolean] then sim(thenCode.expr, newSc) else @@ -226,24 +150,23 @@ object Simulate: case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, sc) - private def simReadBuffer(buf: ReadBuffer[?], sc: SimContext): (Result, SimContext) = buf match + private def simReadBuffer(buf: ReadBuffer[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = buf match case ReadBuffer(buffer, index) => - val (res, sc1) = sim(index.tree, sc) - val i = res.asInstanceOf[Int] - val newSc = sc1.addRead(buffer, i) - (newSc.lookupRead(buffer, i), newSc) + val i = exprMap(index.tree.treeid).asInstanceOf[Int] + val newSc = sc.addRead(buffer, i) + (newSc.lookup(buffer, i), newSc) - private def simReadUniform(uni: ReadUniform[?], sc: SimContext): (Result, SimContext) = uni match + private def simReadUniform(uni: ReadUniform[?]): (Result, SimContext) = uni match case ReadUniform(uniform) => ??? - private def simGArrayElem(gElem: GArrayElem[?], sc: SimContext): (Result, SimContext) = gElem match + private def simGArrayElem(gElem: GArrayElem[?]): (Result, SimContext) = gElem match case GArrayElem(index, i) => ??? - private def simFoldSeq(seq: FoldSeq[?, ?], sc: SimContext): (Result, SimContext) = seq match + private def simFoldSeq(seq: FoldSeq[?, ?]): (Result, SimContext) = seq match case FoldSeq(zero, fn, seq) => ??? - private def simComposeStruct(cs: ComposeStruct[?], sc: SimContext): (Result, SimContext) = cs match + private def simComposeStruct(cs: ComposeStruct[?]): (Result, SimContext) = cs match case ComposeStruct(fields, resultSchema) => ??? - private def simGetField(gf: GetField[?, ?], sc: SimContext): (Result, SimContext) = gf match + private def simGetField(gf: GetField[?, ?]): (Result, SimContext) = gf match case GetField(struct, fieldIndex) => ??? From d7378c087576ac4da478c77cd8617cc9f859f8a1 Mon Sep 17 00:00:00 2001 From: spamegg Date: Sat, 19 Jul 2025 18:40:02 +0300 Subject: [PATCH 10/13] work on interpreter --- .../cyfra/interpreter/Interpreter.scala | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 354f60bb..3fcd3128 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -11,36 +11,41 @@ case class InvocResult( values: List[Any] = Nil, writes: List[Write] = Nil, reads: List[Read] = Nil, -) +): + def merge(that: InvocResult) = InvocResult( + invocId = invocId, + instructions = instructions ::: that.instructions, + values = values ::: that.values, + writes = writes ::: that.writes, + reads = reads ::: that.reads, + ) case class InterpretResult(invocs: List[InvocResult] = Nil): def add(that: InterpretResult) = InterpretResult(that.invocs ::: invocs) object Interpreter: - def interpretPure(gio: Pure[?], invocId: Int, sc: SimContext): InvocResult = gio match - case Pure(value) => InvocResult(invocId, ???, List(value), ???, ???) + private def interpretPure(gio: Pure[?], invocId: Int, sc: SimContext): InvocResult = gio match + case Pure(value) => InvocResult(invocId, ???, List(value), sc.writes, sc.reads) - def interpretWriteBuffer(gio: WriteBuffer[?], invocId: Int, sc: SimContext): InvocResult = gio match + private def interpretWriteBuffer(gio: WriteBuffer[?], invocId: Int, sc: SimContext): InvocResult = gio match case WriteBuffer(buffer, index, value) => val (n, _) = Simulate.sim(index) val i = n.asInstanceOf[Int] val (res, _) = Simulate.sim(value) - val newSc = sc.addWrite(buffer, i, res) - InvocResult(invocId, ???, List(value), ???, ???) + val newSc = sc.addWrite(buffer, i, res) // should we keep this around? + InvocResult(invocId, ???, List(value), newSc.writes, newSc.reads) - def interpretWriteUniform(gio: WriteUniform[?], invocId: Int, sc: SimContext): InvocResult = gio match + private def interpretWriteUniform(gio: WriteUniform[?], invocId: Int, sc: SimContext): InvocResult = gio match case WriteUniform(uniform, value) => ??? - def interpretOne(gio: GIO[?], invocId: Int, sc: SimContext): InvocResult = - // val invocId = Simulate.sim(invocationId, sc, Map()).asInstanceOf[Int] - gio match - case p: Pure[?] => interpretPure(p, invocId, sc) - case wb: WriteBuffer[?] => interpretWriteBuffer(wb, invocId, sc) - case wu: WriteUniform[?] => interpretWriteUniform(wu, invocId, sc) - case _ => throw IllegalArgumentException("interpretOne: invalid GIO") + private def interpretOne(gio: GIO[?], invocId: Int, sc: SimContext): InvocResult = gio match + case p: Pure[?] => interpretPure(p, invocId, sc) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, invocId, sc) + case wu: WriteUniform[?] => interpretWriteUniform(wu, invocId, sc) + case _ => throw IllegalArgumentException("interpretOne: invalid GIO") @annotation.tailrec - def interpretMany(gios: List[GIO[?]], invocId: Int, res: InvocResult, sc: SimContext): InvocResult = gios match + private def interpretMany(gios: List[GIO[?]], invocId: Int, res: InvocResult, sc: SimContext): InvocResult = gios match case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, invocId, res, sc) case Repeat(n, f) :: tail => val (i, _) = Simulate.sim(n) @@ -48,8 +53,9 @@ object Interpreter: val newGios = (0 until int).map(i => f(i)).toList interpretMany(newGios ::: tail, invocId, res, sc) case head :: tail => - val newRes = interpretOne(head, invocId, sc) - interpretMany(tail, invocId, ???, sc) + val newRes = interpretOne(head, invocId, sc) // should we get the updated SimContext? + interpretMany(tail, invocId, res.merge(newRes), sc) case Nil => res def interpret(gio: GIO[?], invocId: Int): InvocResult = interpretMany(List(gio), invocId, InvocResult(invocId), SimContext()) + def interpret(gio: GIO[?], invocIds: List[Int]): InterpretResult = InterpretResult(invocIds.map(interpret(gio, _))) From 9cc655bb7902259c1960621a4a9dd558ba042179 Mon Sep 17 00:00:00 2001 From: spamegg Date: Sun, 20 Jul 2025 13:29:58 +0300 Subject: [PATCH 11/13] change SimContext design a bit --- .../cyfra/e2e/interpreter/SimulateTests.scala | 2 +- .../e2e/interpreter/SimulateWhenTests.scala | 4 +- .../cyfra/interpreter/Interpreter.scala | 64 +++++++++---------- .../cyfra/interpreter/SimContext.scala | 40 +++++++++--- .../cyfra/interpreter/Simulate.scala | 4 +- 5 files changed, 64 insertions(+), 50 deletions(-) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala index 27100174..c561c59c 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -82,4 +82,4 @@ class SimulateE2eTest extends munit.FunSuite: assert(res == exp, s"Expected $exp, got $res") // the context should keep track of the read - assert(newSc.reads.contains(Read(buffer, 128)), "missing read") + assert(newSc.reads.contains(ReadBuf(buffer, 128)), "missing read") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala index bf3ee530..0026ebf5 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -89,5 +89,5 @@ class SimulateWhenE2eTest extends munit.FunSuite: assert(res == exp, s"Expected $exp, got $res") // There should be 2 reads in the simulation context - assert(newSc.reads.contains(Read(buffer, 32))) - assert(newSc.reads.contains(Read(buffer, 64))) + assert(newSc.reads.contains(ReadBuf(buffer, 32))) + assert(newSc.reads.contains(ReadBuf(buffer, 64))) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 3fcd3128..83947ae5 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -8,54 +8,48 @@ import izumi.reflect.Tag case class InvocResult( invocId: Int, instructions: List[Expression[?]] = Nil, - values: List[Any] = Nil, - writes: List[Write] = Nil, - reads: List[Read] = Nil, -): - def merge(that: InvocResult) = InvocResult( - invocId = invocId, - instructions = instructions ::: that.instructions, - values = values ::: that.values, - writes = writes ::: that.writes, - reads = reads ::: that.reads, - ) - -case class InterpretResult(invocs: List[InvocResult] = Nil): - def add(that: InterpretResult) = InterpretResult(that.invocs ::: invocs) + values: List[Result] = Nil, + writes: List[Writes] = Nil, + reads: List[Reads] = Nil, +) + +case class InterpretResult(invocs: List[InvocResult] = Nil) object Interpreter: - private def interpretPure(gio: Pure[?], invocId: Int, sc: SimContext): InvocResult = gio match - case Pure(value) => InvocResult(invocId, ???, List(value), sc.writes, sc.reads) + private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match + case Pure(value) => sc - private def interpretWriteBuffer(gio: WriteBuffer[?], invocId: Int, sc: SimContext): InvocResult = gio match + private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match case WriteBuffer(buffer, index, value) => - val (n, _) = Simulate.sim(index) + val (n, _) = Simulate.sim(index, SimContext()) // Int32, no reads/writes here, don't need resulting context val i = n.asInstanceOf[Int] - val (res, _) = Simulate.sim(value) - val newSc = sc.addWrite(buffer, i, res) // should we keep this around? - InvocResult(invocId, ???, List(value), newSc.writes, newSc.reads) + val (res, sc1) = Simulate.sim(value, sc) + sc1.addWrite(WriteBuf(buffer, i, res)) - private def interpretWriteUniform(gio: WriteUniform[?], invocId: Int, sc: SimContext): InvocResult = gio match - case WriteUniform(uniform, value) => ??? + private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match + case WriteUniform(uniform, value) => ??? // simulate value, then sc.addWrite(WriteUni...) - private def interpretOne(gio: GIO[?], invocId: Int, sc: SimContext): InvocResult = gio match - case p: Pure[?] => interpretPure(p, invocId, sc) - case wb: WriteBuffer[?] => interpretWriteBuffer(wb, invocId, sc) - case wu: WriteUniform[?] => interpretWriteUniform(wu, invocId, sc) + private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match + case p: Pure[?] => interpretPure(p, sc) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, sc) + case wu: WriteUniform[?] => interpretWriteUniform(wu, sc) case _ => throw IllegalArgumentException("interpretOne: invalid GIO") @annotation.tailrec - private def interpretMany(gios: List[GIO[?]], invocId: Int, res: InvocResult, sc: SimContext): InvocResult = gios match - case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, invocId, res, sc) + private def interpretMany(gios: List[GIO[?]], sc: SimContext): SimContext = gios match + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, sc) case Repeat(n, f) :: tail => - val (i, _) = Simulate.sim(n) + val (i, _) = Simulate.sim(n, SimContext()) // just Int32, no reads/writes val int = i.asInstanceOf[Int] val newGios = (0 until int).map(i => f(i)).toList - interpretMany(newGios ::: tail, invocId, res, sc) + interpretMany(newGios ::: tail, sc) case head :: tail => - val newRes = interpretOne(head, invocId, sc) // should we get the updated SimContext? - interpretMany(tail, invocId, res.merge(newRes), sc) - case Nil => res + val newSc = interpretOne(head, sc) + interpretMany(tail, newSc) + case Nil => sc + + def interpret(gio: GIO[?], invocId: Int): InvocResult = + val sc = interpretMany(List(gio), SimContext()) + InvocResult(invocId, ???, sc.values, sc.writes, sc.reads) - def interpret(gio: GIO[?], invocId: Int): InvocResult = interpretMany(List(gio), invocId, InvocResult(invocId), SimContext()) def interpret(gio: GIO[?], invocIds: List[Int]): InterpretResult = InterpretResult(invocIds.map(interpret(gio, _))) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala index c9b98201..17cbf874 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -1,17 +1,37 @@ package io.computenode.cyfra.interpreter import io.computenode.cyfra.dsl.{*, given} -import binding.GBuffer +import binding.{GBuffer, GUniform} import Result.Result -case class SimContext(bufMap: Map[GBuffer[?], Array[Result]] = Map(), reads: List[Read] = Nil, writes: List[Write] = Nil): +enum Reads: + case ReadBuf(buffer: GBuffer[?], index: Int) + case ReadUni(uniform: GUniform[?]) +export Reads.* + +enum Writes: + case WriteBuf(buffer: GBuffer[?], index: Int, value: Result) + case WriteUni(uni: GUniform[?], value: Result) +export Writes.* + +case class SimContext( + bufMap: Map[GBuffer[?], Array[Result]] = Map(), + values: List[Result] = Nil, + writes: List[Writes] = Nil, + reads: List[Reads] = Nil, +): def addBuffer(buffer: GBuffer[?], array: Array[Result]): SimContext = copy(bufMap = bufMap + (buffer -> array)) - def addRead(buffer: GBuffer[?], index: Int): SimContext = copy(reads = Read(buffer, index) :: reads) - def addWrite(buffer: GBuffer[?], index: Int, value: Result): SimContext = - val newArray = bufMap(buffer).updated(index, value) - val newWrites = Write(buffer, index, value) :: writes - copy(bufMap = bufMap.updated(buffer, newArray), writes = newWrites) - def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) -case class Read(buffer: GBuffer[?], index: Int) -case class Write(buffer: GBuffer[?], index: Int, value: Result) + def addRead(read: Reads): SimContext = read match + case ReadBuf(buffer, index) => copy(reads = ReadBuf(buffer, index) :: reads) + case ReadUni(uniform) => ??? + + def addWrite(write: Writes): SimContext = write match + case WriteBuf(buffer, index, value) => + val newArray = bufMap(buffer).updated(index, value) + val newWrites = WriteBuf(buffer, index, value) :: writes + copy(bufMap = bufMap.updated(buffer, newArray), writes = newWrites) + case WriteUni(uni, value) => ??? + + def addResult(res: Result) = copy(values = res :: values) + def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 9a7c2cc8..327a5a2d 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -10,7 +10,7 @@ import collection.mutable.Map as MMap object Simulate: import Result.* - def sim(v: Value): (Result, SimContext) = sim(v.tree) + def sim(v: Value, sc: SimContext): (Result, SimContext) = sim(v.tree, sc) def sim(e: Expression[?], sc: SimContext = SimContext()): (Result, SimContext) = simIterate(buildBlock(e), sc)(using Map()) @annotation.tailrec @@ -153,7 +153,7 @@ object Simulate: private def simReadBuffer(buf: ReadBuffer[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = buf match case ReadBuffer(buffer, index) => val i = exprMap(index.tree.treeid).asInstanceOf[Int] - val newSc = sc.addRead(buffer, i) + val newSc = sc.addRead(ReadBuf(buffer, i)) (newSc.lookup(buffer, i), newSc) private def simReadUniform(uni: ReadUniform[?]): (Result, SimContext) = uni match From 81646120730267df994fba40bf7a5348fee2beb4 Mon Sep 17 00:00:00 2001 From: spamegg Date: Mon, 21 Jul 2025 15:56:02 +0300 Subject: [PATCH 12/13] add uniform read/write --- .../computenode/cyfra/interpreter/Interpreter.scala | 12 ++++++++---- .../computenode/cyfra/interpreter/SimContext.scala | 8 ++++++-- .../io/computenode/cyfra/interpreter/Simulate.scala | 10 ++++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala index 83947ae5..68fcb8af 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -17,17 +17,21 @@ case class InterpretResult(invocs: List[InvocResult] = Nil) object Interpreter: private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match - case Pure(value) => sc + case Pure(value) => + val (result, newSc) = Simulate.sim(value.asInstanceOf[Value], sc) // TODO needs fixing + newSc.addResult(result) private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match case WriteBuffer(buffer, index, value) => val (n, _) = Simulate.sim(index, SimContext()) // Int32, no reads/writes here, don't need resulting context val i = n.asInstanceOf[Int] - val (res, sc1) = Simulate.sim(value, sc) - sc1.addWrite(WriteBuf(buffer, i, res)) + val (res, newSc) = Simulate.sim(value, sc) + newSc.addWrite(WriteBuf(buffer, i, res)) private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match - case WriteUniform(uniform, value) => ??? // simulate value, then sc.addWrite(WriteUni...) + case WriteUniform(uniform, value) => + val (result, newSc) = Simulate.sim(value.asInstanceOf[Value], sc) // TODO needs fixing + newSc.addWrite(WriteUni(uniform, result)) private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match case p: Pure[?] => interpretPure(p, sc) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala index 17cbf874..33917f94 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -16,6 +16,7 @@ export Writes.* case class SimContext( bufMap: Map[GBuffer[?], Array[Result]] = Map(), + uniMap: Map[GUniform[?], Result] = Map(), values: List[Result] = Nil, writes: List[Writes] = Nil, reads: List[Reads] = Nil, @@ -24,14 +25,17 @@ case class SimContext( def addRead(read: Reads): SimContext = read match case ReadBuf(buffer, index) => copy(reads = ReadBuf(buffer, index) :: reads) - case ReadUni(uniform) => ??? + case ReadUni(uniform) => copy(reads = ReadUni(uniform) :: reads) def addWrite(write: Writes): SimContext = write match case WriteBuf(buffer, index, value) => val newArray = bufMap(buffer).updated(index, value) val newWrites = WriteBuf(buffer, index, value) :: writes copy(bufMap = bufMap.updated(buffer, newArray), writes = newWrites) - case WriteUni(uni, value) => ??? + case WriteUni(uni, value) => + val newWrites = WriteUni(uni, value) :: writes + copy(uniMap = uniMap.updated(uni, value), writes = newWrites) def addResult(res: Result) = copy(values = res :: values) def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) + def lookupUni(uniform: GUniform[?]): Result = uniMap(uniform) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala index 327a5a2d..2f691f6c 100644 --- a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -46,7 +46,7 @@ object Simulate: case Dynamic(source) => ??? case e: WhenExpr[?] => simWhen(e, sc) // returns new SimContext case e: ReadBuffer[?] => simReadBuffer(e, sc) // returns new SimContext - case e: ReadUniform[?] => simReadUniform(e) + case e: ReadUniform[?] => simReadUniform(e, sc) // returns new SimContext case e: GArrayElem[?] => simGArrayElem(e) case e: FoldSeq[?, ?] => simFoldSeq(e) case e: ComposeStruct[?] => simComposeStruct(e) @@ -150,14 +150,16 @@ object Simulate: case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, sc) - private def simReadBuffer(buf: ReadBuffer[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = buf match + private def simReadBuffer(e: ReadBuffer[?], sc: SimContext)(using exprMap: Map[Int, Result]): (Result, SimContext) = e match case ReadBuffer(buffer, index) => val i = exprMap(index.tree.treeid).asInstanceOf[Int] val newSc = sc.addRead(ReadBuf(buffer, i)) (newSc.lookup(buffer, i), newSc) - private def simReadUniform(uni: ReadUniform[?]): (Result, SimContext) = uni match - case ReadUniform(uniform) => ??? + private def simReadUniform(uni: ReadUniform[?], sc: SimContext): (Result, SimContext) = uni match + case ReadUniform(uniform) => + val newSc = sc.addRead(ReadUni(uniform)) + (newSc.lookupUni(uniform), newSc) private def simGArrayElem(gElem: GArrayElem[?]): (Result, SimContext) = gElem match case GArrayElem(index, i) => ??? From 0311c6a3fbf3ee0daa69138340374d73fef430af Mon Sep 17 00:00:00 2001 From: spamegg Date: Fri, 25 Jul 2025 15:19:23 +0300 Subject: [PATCH 13/13] change module dependency --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 84e95b4e..836e50e3 100644 --- a/build.sbt +++ b/build.sbt @@ -98,7 +98,7 @@ lazy val vscode = (project in file("cyfra-vscode")) lazy val interpreter = (project in file("cyfra-interpreter")) .settings(commonSettings) - .dependsOn(core) + .dependsOn(dsl, compiler) lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(commonSettings, runnerSettings)