Skip to content
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ on:
push:
branches:
- main
- dev
tags:
- "v*"
pull_request:
Expand Down
8 changes: 6 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,17 @@ lazy val vscode = (project in file("cyfra-vscode"))
.settings(commonSettings)
.dependsOn(foton)

lazy val interpreter = (project in file("cyfra-interpreter"))
.settings(commonSettings)
.dependsOn(dsl, compiler)

lazy val e2eTest = (project in file("cyfra-e2e-test"))
.settings(commonSettings, runnerSettings)
.dependsOn(runtime)
.dependsOn(runtime, interpreter)

lazy val root = (project in file("."))
.settings(name := "Cyfra")
.aggregate(compiler, dsl, foton, core, runtime, vulkan, examples)
.aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, interpreter)

e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.computenode.cyfra.e2e.interpreter

import io.computenode.cyfra.interpreter.*, Result.*
import io.computenode.cyfra.dsl.{*, given}
import binding.*, Value.*, gio.GIO, GIO.*
import Value.FromExpr.fromExpr, control.Scope

class InterpreterE2eTest extends munit.FunSuite:
test("interpret should not stack overflow"):
val pure = Pure(0)
var gio = FlatMap(pure, pure)
for _ <- 0 until 1000000 do gio = FlatMap(pure, gio)
val result = Interpreter.interpret(gio, 0)
val res = 0
val exp = 0
assert(res == exp, s"Expected $exp, got $res")
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.computenode.cyfra.e2e.interpreter

import io.computenode.cyfra.interpreter.*, Result.*
import io.computenode.cyfra.dsl.{*, given}, binding.{ReadBuffer, GBuffer}
import Value.FromExpr.fromExpr, control.Scope
import izumi.reflect.Tag

class SimulateE2eTest extends munit.FunSuite:
test("simulate binary operation arithmetic"):
val a: Int32 = 1
val b: Int32 = 2
val c: Int32 = 3
val d: Int32 = 4
val e: Int32 = 5
val f: Int32 = 6
val e1 = Diff(a, b)
val e2 = Sum(fromExpr(e1), c)
val e3 = Mul(f, fromExpr(e2))
val e4 = Div(fromExpr(e3), d)
val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4)
val (result, _) = Simulate.sim(expr, SimContext())
val expected = 2
assert(result == expected, s"Expected $expected, got $result")

test("simulate vec4, scalar, dot, extract scalar"):
val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f)
val (res1, _) = Simulate.sim(v1, SimContext())
val exp1 = Vector(1f, 2f, 3f, 4f)
assert(res1 == exp1, s"Expected $exp1, got $res1")

val i: Int32 = 2
val expr = ExtractScalar(fromExpr(v1), i)
val (res, _) = Simulate.sim(expr, SimContext())
val exp = 3f
assert(res == exp, s"Expected $exp, got $res")

val v2 = ScalarProd(fromExpr(v1), -1f)
val (res2, _) = Simulate.sim(v2, SimContext())
val exp2 = Vector(-1f, -2f, -3f, -4f)
assert(res2 == exp2, s"Expected $exp2, got $res2")

val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f)
val dot = DotProd(fromExpr(v1), fromExpr(v3))
val (res3a, _) = Simulate.sim(dot, SimContext())
val res3 = res3a.asInstanceOf[Float]
val exp3 = 0f
assert(Math.abs(res3 - exp3) < 0.001f, s"Expected $exp3, got $res3")

test("simulate bitwise ops"):
val a: Int32 = 5
val by: UInt32 = 3
val aNot = BitwiseNot(a)
val left = ShiftLeft(fromExpr(aNot), by)
val right = ShiftRight(fromExpr(aNot), by)
val and = BitwiseAnd(fromExpr(left), fromExpr(right))
val or = BitwiseOr(fromExpr(left), fromExpr(right))
val xor = BitwiseXor(fromExpr(and), fromExpr(or))

val (res, _) = Simulate.sim(xor, SimContext())
val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3))
assert(res == exp, s"Expected $exp, got $res")

test("simulate should not stack overflow"):
val a: Int32 = 1
var sum = Sum(a, a) // 2
for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum))
val (res, _) = Simulate.sim(sum, SimContext())
val exp = 1000002
assert(res == exp, s"Expected $exp, got $res")

test("simulate ReadBuffer"):
// We fake a GBuffer with an array
case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T]
val buffer = SimGBuffer[Int32]()
val array = (0 until 1024).toArray[Result]

val sc = SimContext().addBuffer(buffer, array)

val expr = ReadBuffer(buffer, 128)
val (res, newSc) = Simulate.sim(expr, sc)
val exp = 128
assert(res == exp, s"Expected $exp, got $res")

// the context should keep track of the read
assert(newSc.reads.contains(ReadBuf(buffer, 128)), "missing read")
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package io.computenode.cyfra.e2e.interpreter

import io.computenode.cyfra.interpreter.*, Result.*
import io.computenode.cyfra.dsl.{*, given}
import Value.FromExpr.fromExpr, control.Scope, binding.{GBuffer, ReadBuffer}
import izumi.reflect.Tag

class SimulateWhenE2eTest extends munit.FunSuite:
test("simulate when"):
val expr1 = WhenExpr(
when = 2 >= 1, // true
thenCode = Scope(ConstInt32(1)),
otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 <= 3))),
otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
otherwise = Scope(ConstInt32(3)),
)
val (res1, _) = Simulate.sim(expr1, SimContext())
val exp1 = 1
assert(res1 == exp1, s"Expected $exp1, got $res1")

test("simulate elseWhen first"):
val expr2 = WhenExpr(
when = 2 <= 1, // false
thenCode = Scope(ConstInt32(1)),
otherConds = List(Scope(ConstGB(3 >= 2)) /*true*/, Scope(ConstGB(1 <= 3))),
otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
otherwise = Scope(ConstInt32(3)),
)
val (res2, _) = Simulate.sim(expr2, SimContext())
val exp2 = 2
assert(res2 == exp2, s"Expected $exp2, got $res2")

test("simulate elseWhen second"):
val expr3 = WhenExpr(
when = 2 <= 1, // false
thenCode = Scope(ConstInt32(1)),
otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 <= 3))), // true
otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
otherwise = Scope(ConstInt32(3)),
)
val (res3, _) = Simulate.sim(expr3, SimContext())
val exp3 = 4
assert(res3 == exp3, s"Expected $exp3, got $res3")

test("simulate otherwise"):
val expr4 = WhenExpr(
when = 2 <= 1, // false
thenCode = Scope(ConstInt32(1)),
otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 >= 3))), // false
otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
otherwise = Scope(ConstInt32(3)),
)
val (res4, _) = Simulate.sim(expr4, SimContext())
val exp4 = 3
assert(res4 == exp4, s"Expected $exp4, got $res4")

test("simulate mixed arithmetic, reads and when"):
case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T]
val buffer = SimGBuffer[Int32]()
val array = (128 until 0 by -1).toArray[Result]

val sc = SimContext().addBuffer(buffer, array)

val a: Int32 = 32
val b: Int32 = 64
val c: Int32 = 4

val readExpr1 = ReadBuffer(buffer, a) // 96
val expr1 = Mul(c, fromExpr(readExpr1)) // 4 * 96 = 384

val readExpr2 = ReadBuffer(buffer, b) // 64
val expr2 = Sum(c, fromExpr(readExpr2)) // 4 + 64 = 68

val expr3 = Mod(fromExpr(expr2), 5) // 68 % 5 = 3

val cond1 = fromExpr(expr1) <= fromExpr(expr2) // 384 <= 68 false
val cond2 = Equal(fromExpr(expr1), fromExpr(expr2)) // 384 == 68 false
val cond3 = GreaterThanEqual(fromExpr(expr3), fromExpr(expr2)) // 3 >= 68 false

val expr = WhenExpr(
when = cond1, // false
thenCode = Scope(expr1),
otherConds = List(Scope(cond2), Scope(cond3)), // false false
otherCaseCodes = List(Scope(expr1), Scope(expr2)), // 384, 68
otherwise = Scope(expr3), // 3
)
val (res, newSc) = Simulate.sim(expr, sc)
val exp = 3
assert(res == exp, s"Expected $exp, got $res")

// There should be 2 reads in the simulation context
assert(newSc.reads.contains(ReadBuf(buffer, 32)))
assert(newSc.reads.contains(ReadBuf(buffer, 64)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.computenode.cyfra.interpreter

import io.computenode.cyfra.dsl.{*, given}
import binding.*, Value.*, gio.GIO, GIO.*
import Result.Result
import izumi.reflect.Tag

case class InvocResult(
invocId: Int,
instructions: List[Expression[?]] = Nil,
values: List[Result] = Nil,
writes: List[Writes] = Nil,
reads: List[Reads] = Nil,
)

case class InterpretResult(invocs: List[InvocResult] = Nil)

object Interpreter:
private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match
case Pure(value) =>
val (result, newSc) = Simulate.sim(value.asInstanceOf[Value], sc) // TODO needs fixing
newSc.addResult(result)

private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match
case WriteBuffer(buffer, index, value) =>
val (n, _) = Simulate.sim(index, SimContext()) // Int32, no reads/writes here, don't need resulting context
val i = n.asInstanceOf[Int]
val (res, newSc) = Simulate.sim(value, sc)
newSc.addWrite(WriteBuf(buffer, i, res))

private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match
case WriteUniform(uniform, value) =>
val (result, newSc) = Simulate.sim(value.asInstanceOf[Value], sc) // TODO needs fixing
newSc.addWrite(WriteUni(uniform, result))

private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match
case p: Pure[?] => interpretPure(p, sc)
case wb: WriteBuffer[?] => interpretWriteBuffer(wb, sc)
case wu: WriteUniform[?] => interpretWriteUniform(wu, sc)
case _ => throw IllegalArgumentException("interpretOne: invalid GIO")

@annotation.tailrec
private def interpretMany(gios: List[GIO[?]], sc: SimContext): SimContext = gios match
case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, sc)
case Repeat(n, f) :: tail =>
val (i, _) = Simulate.sim(n, SimContext()) // just Int32, no reads/writes
val int = i.asInstanceOf[Int]
val newGios = (0 until int).map(i => f(i)).toList
interpretMany(newGios ::: tail, sc)
case head :: tail =>
val newSc = interpretOne(head, sc)
interpretMany(tail, newSc)
case Nil => sc

def interpret(gio: GIO[?], invocId: Int): InvocResult =
val sc = interpretMany(List(gio), SimContext())
InvocResult(invocId, ???, sc.values, sc.writes, sc.reads)

def interpret(gio: GIO[?], invocIds: List[Int]): InterpretResult = InterpretResult(invocIds.map(interpret(gio, _)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.computenode.cyfra.interpreter

object Result:
export ScalarResult.*, VectorResult.*

type Result = ScalarRes | Vector[ScalarRes]

extension (r: Result)
def negate: Result = r match
case s: ScalarRes => s.neg
case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd

def bitNeg: Int = r match
case sr: ScalarRes => ~sr
case _ => throw IllegalArgumentException("bitNeg: wrong argument type")

def shiftLeft(by: Result): Int = (r, by) match
case (n: ScalarRes, b: ScalarRes) => n << b
case _ => throw IllegalArgumentException("shiftLeft: incompatible argument types")

def shiftRight(by: Result): Int = (r, by) match
case (n: ScalarRes, b: ScalarRes) => n >> b
case _ => throw IllegalArgumentException("shiftRight: incompatible argument types")

def bitAnd(that: Result): Int = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s & t
case _ => throw IllegalArgumentException("bitAnd: incompatible argument types")

def bitOr(that: Result): Int = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s | t
case _ => throw IllegalArgumentException("bitOr: incompatible argument types")

def bitXor(that: Result): Int = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s ^ t
case _ => throw IllegalArgumentException("bitXor: incompatible argument types")

def add(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s + t
case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t
case _ => throw IllegalArgumentException("add: incompatible argument types")

def sub(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s - t
case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t
case _ => throw IllegalArgumentException("sub: incompatible argument types")

def mul(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s * t
case _ => throw IllegalArgumentException("mul: incompatible argument types")

def div(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s / t
case _ => throw IllegalArgumentException("div: incompatible argument types")

def mod(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s % t
case _ => throw IllegalArgumentException("mod: incompatible argument types")

def scale(that: Result): Result = (r, that) match
case (v: Vector[ScalarRes], t: ScalarRes) => v scale t
case _ => throw IllegalArgumentException("scale: incompatible argument types")

def dot(that: Result): Result = (r, that) match
case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v dot t
case _ => throw IllegalArgumentException("dot: incompatible argument types")

def &&(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s && t
case _ => throw IllegalArgumentException("&&: incompatible argument types")

def ||(that: Result): Result = (r, that) match
case (s: ScalarRes, t: ScalarRes) => s || t
case _ => throw IllegalArgumentException("||: incompatible argument types")

def gt(that: Result): Boolean = (r, that) match
case (sr: ScalarRes, t: ScalarRes) => sr > t
case _ => throw IllegalArgumentException("gt: incompatible argument types")

def lt(that: Result): Boolean = (r, that) match
case (sr: ScalarRes, t: ScalarRes) => sr < t
case _ => throw IllegalArgumentException("lt: incompatible argument types")

def gteq(that: Result): Boolean = (r, that) match
case (sr: ScalarRes, t: ScalarRes) => sr >= t
case _ => throw IllegalArgumentException("gteq: incompatible argument types")

def lteq(that: Result): Boolean = (r, that) match
case (sr: ScalarRes, t: ScalarRes) => sr <= t
case _ => throw IllegalArgumentException("lteq: incompatible argument types")

def eql(that: Result): Boolean = (r, that) match
case (sr: ScalarRes, t: ScalarRes) => sr === t
case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t
case _ => throw IllegalArgumentException("eql: incompatible argument types")
Loading