Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
pull_request:

jobs:
format:
format_and_compile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -19,4 +19,4 @@ jobs:
with:
jvm: graalvm-java21
apps: sbt
- run: sbt "formatCheckAll"
- run: sbt "formatCheckAll; compile"
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ optIn.configStyleArguments = false
rewrite.rules = [RedundantBraces, RedundantParens, SortModifiers, PreferCurlyFors, Imports]
rewrite.sortModifiers.preset = styleGuide
rewrite.trailingCommas.style = always
rewrite.scala3.convertToNewSyntax = true

indent.defnSite = 2
newlines.inInterpolation = "avoid"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import scala.quoted.Expr

private[cyfra] object BlockBuilder:

def buildBlock(tree: E[_], providedExprIds: Set[Int] = Set.empty): List[E[_]] =
val allVisited = mutable.Map[Int, E[_]]()
def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] =
val allVisited = mutable.Map[Int, E[?]]()
val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0)
val q = mutable.Queue[E[_]]()
val q = mutable.Queue[E[?]]()
q.enqueue(tree)
allVisited(tree.treeid) = tree

Expand All @@ -28,8 +28,8 @@ private[cyfra] object BlockBuilder:
allVisited(childId) = child
q.enqueue(child)

val l = mutable.ListBuffer[E[_]]()
val roots = mutable.Queue[E[_]]()
val l = mutable.ListBuffer[E[?]]()
val roots = mutable.Queue[E[?]]()
allVisited.values.foreach: node =>
if inDegrees(node.treeid) == 0 then roots.enqueue(node)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ private[cyfra] case class Context(
voidFuncTypeRef: Int = -1,
workerIndexRef: Int = -1,
uniformVarRef: Int = -1,
constRefs: Map[(Tag[_], Any), Int] = Map(),
constRefs: Map[(Tag[?], Any), Int] = Map(),
exprRefs: Map[Int, Int] = Map(),
inBufferBlocks: List[ArrayBufferBlock] = List(),
outBufferBlocks: List[ArrayBufferBlock] = List(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ private[cyfra] object Opcodes {

def length = 1

override def toString = s"Word(${bytes.mkString(", ")}${if (bytes.length == 4) s" [i = ${BigInt(bytes).toInt}])" else ""}"
override def toString = s"Word(${bytes.mkString(", ")}${if bytes.length == 4 then s" [i = ${BigInt(bytes).toInt}])" else ""}"
}

private[cyfra] case class WordVariable(name: String) extends Words {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ private[cyfra] object SpirvTypes:
val UInt32Tag = summon[Tag[UInt32]]
val Float32Tag = summon[Tag[Float32]]
val GBooleanTag = summon[Tag[GBoolean]]
val Vec2TagWithoutArgs = summon[Tag[Vec2[_]]].tag.withoutArgs
val Vec3TagWithoutArgs = summon[Tag[Vec3[_]]].tag.withoutArgs
val Vec4TagWithoutArgs = summon[Tag[Vec4[_]]].tag.withoutArgs
val Vec2Tag = summon[Tag[Vec2[_]]]
val Vec3Tag = summon[Tag[Vec3[_]]]
val Vec4Tag = summon[Tag[Vec4[_]]]
val VecTag = summon[Tag[Vec[_]]]
val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs
val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs
val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs
val Vec2Tag = summon[Tag[Vec2[?]]]
val Vec3Tag = summon[Tag[Vec3[?]]]
val Vec4Tag = summon[Tag[Vec4[?]]]
val VecTag = summon[Tag[Vec[?]]]

val LInt32Tag = Int32Tag.tag
val LUInt32Tag = UInt32Tag.tag
Expand All @@ -37,7 +37,7 @@ private[cyfra] object SpirvTypes:
type Vec3C[T <: Value] = Vec3[T]
type Vec4C[T <: Value] = Vec4[T]

def scalarTypeDefInsn(tag: Tag[_], typeDefIndex: Int) = tag match {
def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match {
case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1)))
case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0)))
case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32)))
Expand All @@ -57,9 +57,9 @@ private[cyfra] object SpirvTypes:
case v if v <:< LVecTag =>
vecSize(v) * typeStride(v.typeArgs.head)

def typeStride(tag: Tag[_]): Int = typeStride(tag.tag)
def typeStride(tag: Tag[?]): Int = typeStride(tag.tag)

def toWord(tpe: Tag[_], value: Any): Words = tpe match {
def toWord(tpe: Tag[?], value: Any): Words = tpe match {
case t if t == Int32Tag =>
IntWord(value.asInstanceOf[Int])
case t if t == UInt32Tag =>
Expand All @@ -73,7 +73,7 @@ private[cyfra] object SpirvTypes:
Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray)
}

def defineScalarTypes(types: List[Tag[_]], context: Context): (List[Words], Context) =
def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) =
val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag)
(basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) =>
val typeDefIndex = ctx.nextResultId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,29 @@ import scala.util.Random
private[cyfra] object DSLCompiler:

// TODO: Not traverse same fn scopes for each fn call
private def getAllExprsFlattened(root: E[_], visitDetached: Boolean): List[E[_]] =
private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] =
var blockI = 0
val allScopesCache = mutable.Map[Int, List[E[_]]]()
val allScopesCache = mutable.Map[Int, List[E[?]]]()
val visited = mutable.Set[Int]()
@tailrec
def getAllScopesExprsAcc(toVisit: List[E[_]], acc: List[E[_]] = Nil): List[E[_]] = toVisit match
def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match
case Nil => acc
case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc)
case e :: tail =>
if (allScopesCache.contains(root.treeid))
return allScopesCache(root.treeid)
if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid)
val eScopes = e.introducedScopes
val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached)
val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr)
val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc
visited += e.treeid
blockI += 1
if (blockI % 100 == 0)
allScopesCache.update(e.treeid, result)
if blockI % 100 == 0 then allScopesCache.update(e.treeid, result)
getAllScopesExprsAcc(newToVisit, result)
val result = root :: getAllScopesExprsAcc(root :: Nil)
allScopesCache(root.treeid) = result
result

def compile(tree: Value, inTypes: List[Tag[_]], outTypes: List[Tag[_]], uniformSchema: GStructSchema[_]): ByteBuffer =
def compile(tree: Value, inTypes: List[Tag[?]], outTypes: List[Tag[?]], uniformSchema: GStructSchema[?]): ByteBuffer =
val treeExpr = tree.tree
val allExprs = getAllExprsFlattened(treeExpr, visitDetached = true)
val typesInCode = allExprs.map(_.tag).distinct
Expand All @@ -59,8 +57,8 @@ private[cyfra] object DSLCompiler:
val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext)
val structsInCode =
(allExprs.collect {
case cs: ComposeStruct[_] => cs.resultSchema
case gf: GetField[_, _] => gf.resultSchema
case cs: ComposeStruct[?] => cs.resultSchema
case gf: GetField[?, ?] => gf.resultSchema
} :+ uniformSchema).distinct
val (structDefs, structCtx) = defineStructTypes(structsInCode, typedContext)
val structNames = getStructNames(structsInCode, structCtx)
Expand Down
Loading