diff --git a/build.sbt b/build.sbt
--- a/build.sbt
+++ b/build.sbt
@@ -14,6 +14,15 @@ val `oolong-bson` = (project in file("oolong-bson"))
     Test / fork := true,
   )
 
+val `oolong-json` = (project in file("oolong-json"))
+  .settings(Settings.common)
+  .settings(
+    libraryDependencies ++= Seq(
+      "org.apache.commons" % "commons-text" % "1.9"
+    ),
+    Test / fork := true
+  )
+
 val `oolong-core` = (project in file("oolong-core"))
   .settings(Settings.common)
   .dependsOn(`oolong-bson`)
@@ -40,6 +49,18 @@ val `oolong-mongo` = (project in file("oolong-mongo"))
     Test / fork := true
   )
 
+val `oolong-elasticsearch` = (project in file("oolong-elasticsearch"))
+  .settings(Settings.common)
+  .dependsOn(`oolong-core`, `oolong-json`)
+  .settings(
+    libraryDependencies ++= Seq(
+      "org.scalatest" %% "scalatest"    % "3.2.11" % Test,
+      "org.slf4j"      % "slf4j-api"    % "1.7.36" % Test,
+      "org.slf4j"      % "slf4j-simple" % "1.7.36" % Test,
+    ),
+    Test / fork := true
+  )
+
 val root = (project in file("."))
   .settings(Settings.common)
-  .aggregate(`oolong-bson`, `oolong-core`, `oolong-mongo`)
+  .aggregate(`oolong-bson`, `oolong-json`, `oolong-core`, `oolong-mongo`, `oolong-elasticsearch`)
diff --git a/oolong-core/src/main/scala/ru/tinkoff/oolong/Utils.scala b/oolong-core/src/main/scala/ru/tinkoff/oolong/Utils.scala
index 4fc90e8..d86aedf 100644
--- a/oolong-core/src/main/scala/ru/tinkoff/oolong/Utils.scala
+++ b/oolong-core/src/main/scala/ru/tinkoff/oolong/Utils.scala
@@ -179,3 +179,7 @@ private[oolong] object Utils:
     def unapply(expr: Expr[Pattern])(using q: Quotes): Option[Pattern] =
       import q.reflect.*
       AsRegexPattern.unapply(expr)
+
+  extension [A](sq: Seq[A]) {
+    def pforall(pf: PartialFunction[A, Boolean]): Boolean = sq.forall(pf.applyOrElse(_, _ => false))
+  }
diff --git a/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/ElasticQueryCompiler.scala b/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/ElasticQueryCompiler.scala
new file mode 100644
--- /dev/null
+++ b/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/ElasticQueryCompiler.scala
@@ -0,0 +1,167 @@
+package ru.tinkoff.oolong.elasticsearch
+
+import scala.quoted.Expr
+import scala.quoted.Quotes
+
+import org.bson.json.JsonMode
+
+import ru.tinkoff.oolong.*
+import ru.tinkoff.oolong.Utils.*
+import ru.tinkoff.oolong.elasticsearch.ElasticQueryNode as EQN
+
+/**
+ * Compiles the query AST into an Elasticsearch query tree, renders that tree as a diagnostic string and turns it
+ * into an expression that builds the corresponding `JsonNode`.
+ */
+object ElasticQueryCompiler extends Backend[QExpr, ElasticQueryNode, JsonNode] {
+
+  /** Translates the query AST into Elasticsearch nodes; aborts compilation on unsupported shapes. */
+  override def opt(ast: QExpr)(using quotes: Quotes): ElasticQueryNode = {
+    import quotes.reflect.*
+
+    ast match {
+      case QExpr.Prop(path)                               => EQN.Field(path)
+      case QExpr.Eq(QExpr.Prop(path), QExpr.Constant(s))  => EQN.Term(EQN.Field(path), EQN.Constant(s))
+      case QExpr.Ne(QExpr.Prop(path), QExpr.Constant(s)) =>
+        EQN.Bool(mustNot = EQN.Term(EQN.Field(path), EQN.Constant(s)) :: Nil)
+      case QExpr.Gte(QExpr.Prop(path), QExpr.Constant(s)) => EQN.Range(EQN.Field(path), gte = Some(EQN.Constant(s)))
+      case QExpr.Lte(QExpr.Prop(path), QExpr.Constant(s)) => EQN.Range(EQN.Field(path), lte = Some(EQN.Constant(s)))
+      case QExpr.Gt(QExpr.Prop(path), QExpr.Constant(s))  => EQN.Range(EQN.Field(path), gt = Some(EQN.Constant(s)))
+      case QExpr.Lt(QExpr.Prop(path), QExpr.Constant(s))  => EQN.Range(EQN.Field(path), lt = Some(EQN.Constant(s)))
+      case QExpr.And(exprs)                               => EQN.Bool(must = exprs map opt)
+      case QExpr.Or(exprs)                                => EQN.Bool(should = exprs map opt)
+      case QExpr.Not(expr)                                => EQN.Bool(mustNot = opt(expr) :: Nil)
+      case QExpr.Exists(QExpr.Prop(path), QExpr.Constant(true)) => EQN.Exists(EQN.Field(path))
+      case QExpr.Exists(QExpr.Prop(path), QExpr.Constant(false)) =>
+        EQN.Bool(mustNot = EQN.Exists(EQN.Field(path)) :: Nil)
+      case unhandled => report.errorAndAbort(s"Unprocessable expression: $unhandled")
+    }
+  }
+
+  /** Extracts the field reference from a property expression; aborts compilation for any other expression. */
+  def getField(f: QExpr)(using quotes: Quotes): EQN.Field =
+    import quotes.reflect.*
+    f match
+      case QExpr.Prop(path) => EQN.Field(path)
+      case _                => report.errorAndAbort("Field is of wrong type")
+
+  /** Renders the query tree as a JSON-like string; a bare field has no rendering and aborts compilation. */
+  override def render(node: ElasticQueryNode)(using quotes: Quotes): String = {
+    import quotes.reflect.*
+
+    node match {
+      case EQN.Term(EQN.Field(path), x) =>
+        s"""{ "term": {"${path.mkString(".")}": ${render(x)} } }"""
+      case EQN.Bool(must, should, mustNot) =>
+        // wrapped in "bool" so the diagnostic output has the same shape as the JSON built by `target`
+        s"""{"bool": {"must": [${must.map(render).mkString(", ")}], "should": [${should
+            .map(render)
+            .mkString(", ")}], "must_not": [${mustNot.map(render).mkString(", ")}]}}"""
+      // string constants go through JsonNode.Str so that quotes and control characters are escaped
+      case EQN.Constant(s: String) => JsonNode.Str(s).render
+      case EQN.Constant(s: Any)    => s.toString
+      case EQN.Exists(EQN.Field(path)) =>
+        s"""{ "exists": { "field": "${path.mkString(".")}" }}"""
+      case EQN.Range(EQN.Field(path), gt, gte, lt, lte) =>
+        val bounds = Seq(
+          renderKeyMap(""""gt"""", gt),
+          renderKeyMap(""""gte"""", gte),
+          renderKeyMap(""""lt"""", lt),
+          renderKeyMap(""""lte"""", lte)
+        ).flatten
+        s"""{"range": {"${path.mkString(".")}": {${bounds.mkString(",")}}}}"""
+      case EQN.Field(field) =>
+        // TODO: adjust error message
+        report.errorAndAbort(s"There is no filter condition on field ${field.mkString(".")}")
+      case _ => "AST can't be rendered"
+    }
+  }
+
+  /** Renders an optional range bound as `key:value`; yields `None` when the bound is not set. */
+  private def renderKeyMap(key: String, node: Option[ElasticQueryNode])(using quotes: Quotes): Option[String] =
+    node.map(render(_)).map(v => s"""$key:$v""")
+
+  /** Builds the expression that constructs the `JsonNode` for the query tree; aborts on nodes it cannot place. */
+  override def target(optRepr: ElasticQueryNode)(using quotes: Quotes): Expr[JsonNode] = {
+    import quotes.reflect.*
+
+    optRepr match {
+      case bool: EQN.Bool =>
+        '{
+          JsonNode.obj(
+            "bool" -> JsonNode.obj(
+              "must"     -> JsonNode.Arr(${ Expr.ofSeq(bool.must.map(target)) }),
+              "should"   -> JsonNode.Arr(${ Expr.ofSeq(bool.should.map(target)) }),
+              "must_not" -> JsonNode.Arr(${ Expr.ofSeq(bool.mustNot.map(target)) }),
+            )
+          )
+        }
+      case EQN.Term(EQN.Field(path), x) =>
+        '{ JsonNode.obj("term" -> JsonNode.obj(${ Expr(path.mkString(".")) } -> ${ handleValues(x) })) }
+      case EQN.Exists(EQN.Field(path)) =>
+        '{ JsonNode.obj("exists" -> JsonNode.obj("field" -> JsonNode.Str(${ Expr(path.mkString(".")) }))) }
+      case EQN.Range(EQN.Field(path), gt, gte, lt, lte) =>
+        // only the bounds that are actually set are emitted; unset ones are left out instead of being sent as null
+        val bounds = List("gt" -> gt, "gte" -> gte, "lt" -> lt, "lte" -> lte).collect {
+          case (name, Some(bound)) => '{ ${ Expr(name) } -> ${ handleValues(bound) } }
+        }
+        '{
+          JsonNode.obj(
+            "range" -> JsonNode.obj(
+              ${ Expr(path.mkString(".")) } -> JsonNode.Obj(${ Expr.ofSeq(bounds) }.toMap)
+            )
+          )
+        }
+      case _ => report.errorAndAbort("given node can't be in that position")
+    }
+  }
+
+  /** Lifts a literal constant node into the matching `JsonNode` expression; aborts for any other node. */
+  def handleValues(expr: ElasticQueryNode)(using q: Quotes): Expr[JsonNode] =
+    import q.reflect.*
+
+    expr match {
+      case EQN.Constant(i: Long) =>
+        '{ JsonNode.Num.apply(BigDecimal.apply(${ Expr(i: Long) })) }
+      case EQN.Constant(i: Int) =>
+        '{ JsonNode.Num.apply(BigDecimal.apply(${ Expr(i: Int) })) }
+      case EQN.Constant(i: Short) =>
+        '{ JsonNode.Num.apply(BigDecimal.apply(${ Expr(i: Short) })) }
+      case EQN.Constant(i: Byte) =>
+        '{ JsonNode.Num.apply(BigDecimal.apply(${ Expr(i: Byte) })) }
+      case EQN.Constant(i: Double) =>
+        '{ JsonNode.Num.apply(BigDecimal.apply(${ Expr(i: Double) })) }
+      case EQN.Constant(i: Float) =>
+        '{ JsonNode.Num.apply(BigDecimal.apply(${ Expr(i: Float) })) }
+      case EQN.Constant(i: String) =>
+        '{ JsonNode.Str.apply(${ Expr(i: String) }) }
+      case EQN.Constant(i: Boolean) =>
+        '{ JsonNode.Bool.apply(${ Expr(i: Boolean) }) }
+      case _ => report.errorAndAbort("Given type is not literal constant")
+    }
+
+  /**
+   * Flattens nested conjunctions and lifts negations into `must_not`. Disjunctions stay nested inside `must`: a
+   * `should` list placed next to `must` clauses is optional in Elasticsearch, so hoisting it would drop the condition.
+   */
+  override def optimize(query: ElasticQueryNode): ElasticQueryNode = query match {
+    case EQN.Bool(must, should, mustNot) =>
+      val mustBuilder    = List.newBuilder[ElasticQueryNode]
+      val mustNotBuilder = List.newBuilder[ElasticQueryNode].addAll(mustNot.map(optimize))
+
+      for (mp <- must.map(optimize)) mp match {
+        case EQN.Bool.And(must2) =>
+          must2.foreach(mustBuilder += _)
+        // (!a || !b || !c) is equivalent to !(a && b && c)
+        case EQN.Bool.Or(negated) if negated.pforall { case EQN.Bool.Not(_) => true } =>
+          mustNotBuilder += EQN.Bool(must = negated.collect { case EQN.Bool.Not(inner) => inner })
+        case EQN.Bool.Not(not) =>
+          mustNotBuilder += not
+        case other => mustBuilder += other
+      }
+
+      EQN.Bool(mustBuilder.result, should.map(optimize), mustNotBuilder.result)
+
+    case other => other
+  }
+}
diff --git a/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/ElasticQueryNode.scala b/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/ElasticQueryNode.scala
new file mode 100644
index 0000000..b825ab5
--- /dev/null
+++ b/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/ElasticQueryNode.scala
@@ -0,0 +1,49 @@
+package ru.tinkoff.oolong.elasticsearch
+
+sealed trait ElasticQueryNode
+
+object ElasticQueryNode {
+  case class Field(path: List[String]) extends ElasticQueryNode
+
+  case class Term(field: Field, expr: ElasticQueryNode) extends ElasticQueryNode
+  case class Constant[T](s: T)                          extends ElasticQueryNode
+
+  case class Exists(x: ElasticQueryNode) extends ElasticQueryNode
+
+  case class Bool(
+      must: List[ElasticQueryNode] = Nil,
+      should: List[ElasticQueryNode] = Nil,
+      mustNot: List[ElasticQueryNode] = Nil
+  ) extends ElasticQueryNode
+
+  object Bool {
+    object And {
+      def unapply(bool: Bool): Option[List[ElasticQueryNode]] = bool match {
+        case Bool(and, Nil, Nil) => Some(and)
+        case _                   => None
+      }
+    }
+
+    object Or {
+      def unapply(bool: Bool): Option[List[ElasticQueryNode]] = bool match {
+        case Bool(Nil, or, Nil) => Some(or)
+        case _                  => None
+      }
+    }
+
+    object Not {
+      def unapply(bool: Bool): Option[ElasticQueryNode] = bool match {
+        case Bool(Nil, Nil, List(not)) => Some(not)
+        case _                         => None
+      }
+    }
+  }
+
+  case class Range(
+      field: Field,
+      gt: Option[ElasticQueryNode] = None,
+      gte: Option[ElasticQueryNode] = None,
+      lt: Option[ElasticQueryNode] = None,
+      lte: Option[ElasticQueryNode] = None
+  ) extends ElasticQueryNode
+}
diff --git a/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/QueryCompiler.scala b/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/QueryCompiler.scala
new file mode 100644
index 0000000..2602879
--- /dev/null
+++ b/oolong-elasticsearch/src/main/scala/ru/tinkoff/oolong/elasticsearch/QueryCompiler.scala
@@ -0,0 +1,30 @@
+package ru.tinkoff.oolong.elasticsearch
+
+import scala.quoted.*
+
+import ru.tinkoff.oolong.*
+import ru.tinkoff.oolong.dsl.*
+
+/**
+ * Compile a ES query.
+ * @param input
+ *   Scala code describing the query.
+ */
+inline def query[Doc](inline input: Doc => Boolean): JsonNode = ${ queryImpl('input) }
+
+private[oolong] def queryImpl[Doc: Type](input: Expr[Doc => Boolean])(using quotes: Quotes): Expr[JsonNode] = {
+  import quotes.reflect.*
+  import ElasticQueryCompiler.*
+
+  val parser = new DefaultAstParser
+
+  val ast          = parser.parseQExpr(input)
+  val optimizedAst = LogicalOptimizer.optimize(ast)
+
+  val optRepr   = opt(optimizedAst)
+  val optimized = optimize(optRepr)
+
+  report.info("Optimized AST:\n" + pprint(optimizedAst) + "\nGenerated query:\n" + render(optimized))
+
+  '{ JsonNode.obj("query" -> ${ target(optimized) }) }
+}
diff --git a/oolong-elasticsearch/src/test/scala/ru/tinkoff/oolong/elasticsearch/QuerySpec.scala b/oolong-elasticsearch/src/test/scala/ru/tinkoff/oolong/elasticsearch/QuerySpec.scala
new file mode 100644
--- /dev/null
+++ b/oolong-elasticsearch/src/test/scala/ru/tinkoff/oolong/elasticsearch/QuerySpec.scala
@@ -0,0 +1,74 @@
+package ru.tinkoff.oolong.elasticsearch
+
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers.shouldBe
+
+import ru.tinkoff.oolong.dsl.*
+
+class QuerySpec extends AnyFunSuite {
+  test("Term query") {
+    val q = query[TestClass](_.field2 == 2)
+
+    q.render shouldBe """{"query":{"term":{"field2":2}}}"""
+  }
+
+  test("&& query") {
+    val q = query[TestClass](c => c.field1 == "check" && c.field2 == 42)
+
+    q.render shouldBe """{"query":{"bool":{"must":[{"term":{"field1":"check"}},{"term":{"field2":42}}],"should":[],"must_not":[]}}}"""
+  }
+
+  test("|| query") {
+    val q = query[TestClass](c => c.field1 == "check" || c.field2 == 42)
+
+    q.render shouldBe """{"query":{"bool":{"must":[],"should":[{"term":{"field1":"check"}},{"term":{"field2":42}}],"must_not":[]}}}"""
+  }
+
+  test("!= query") {
+    val q = query[TestClass](_.field2 != 2)
+
+    q.render shouldBe """{"query":{"bool":{"must":[],"should":[],"must_not":[{"term":{"field2":2}}]}}}"""
+  }
+
+  test("Composite boolean query") {
+    val q = query[TestClass](c => c.field1 == "check" && (c.field2 == 42 || c.field3.innerField == "inner"))
+
+    q.render shouldBe """{"query":{"bool":{"must":[{"term":{"field1":"check"}},{"bool":{"must":[],"should":[{"term":{"field2":42}},{"term":{"field3.innerField":"inner"}}],"must_not":[]}}],"should":[],"must_not":[]}}}"""
+  }
+
+  test(".isDefined query") {
+    val q = query[TestClass](_.field3.optionalInnerField.isDefined)
+
+    q.render shouldBe """{"query":{"exists":{"field":"field3.optionalInnerField"}}}"""
+  }
+
+  test(".isEmpty query") {
+    val q = query[TestClass](_.field3.optionalInnerField.isEmpty)
+
+    q.render shouldBe """{"query":{"bool":{"must":[],"should":[],"must_not":[{"exists":{"field":"field3.optionalInnerField"}}]}}}"""
+  }
+
+  test("> query") {
+    val q = query[TestClass](_.field2 > 4)
+
+    q.render shouldBe """{"query":{"range":{"field2":{"gt":4}}}}"""
+  }
+
+  test(">= query") {
+    val q = query[TestClass](_.field2 >= 4)
+
+    q.render shouldBe """{"query":{"range":{"field2":{"gte":4}}}}"""
+  }
+
+  test("< query") {
+    val q = query[TestClass](_.field2 < 4)
+
+    q.render shouldBe """{"query":{"range":{"field2":{"lt":4}}}}"""
+  }
+
+  test("<= query") {
+    val q = query[TestClass](_.field2 <= 4)
+
+    q.render shouldBe """{"query":{"range":{"field2":{"lte":4}}}}"""
+  }
+}
diff --git a/oolong-elasticsearch/src/test/scala/ru/tinkoff/oolong/elasticsearch/TestDomain.scala b/oolong-elasticsearch/src/test/scala/ru/tinkoff/oolong/elasticsearch/TestDomain.scala
new file mode 100644
index 0000000..72552b4
--- /dev/null
+++ b/oolong-elasticsearch/src/test/scala/ru/tinkoff/oolong/elasticsearch/TestDomain.scala
@@ -0,0 +1,13 @@
+package ru.tinkoff.oolong.elasticsearch
+
+case class TestClass(
+    field1: String,
+    field2: Int,
+    field3: InnerClass,
+    field4: List[Int]
+)
+
+case class InnerClass(
+    innerField: String,
+    optionalInnerField: Option[Int]
+)
diff --git a/oolong-json/src/main/scala/ru/tinkoff/oolong/JsonNode.scala b/oolong-json/src/main/scala/ru/tinkoff/oolong/JsonNode.scala
new file mode 100644
--- /dev/null
+++ b/oolong-json/src/main/scala/ru/tinkoff/oolong/JsonNode.scala
@@ -0,0 +1,38 @@
+package ru.tinkoff.oolong
+
+import scala.compiletime.ops.boolean
+
+import org.apache.commons.text.StringEscapeUtils
+
+/** Minimal JSON tree; every node renders itself as compact JSON text. */
+sealed private[oolong] trait JsonNode {
+  def render: String
+}
+
+private[oolong] object JsonNode {
+  case object Null extends JsonNode {
+    override def render: String = "null"
+  }
+  case class Bool(value: Boolean) extends JsonNode {
+    override def render: String = value.toString
+  }
+  case class Num(value: BigDecimal) extends JsonNode {
+    override def render: String = value.toString
+  }
+  case class Str(value: String) extends JsonNode {
+    override def render: String = s"\"${StringEscapeUtils.escapeJson(value)}\""
+  }
+  case class Arr(value: Seq[JsonNode]) extends JsonNode {
+    override def render: String = value.map(_.render).mkString("[", ",", "]")
+  }
+  case class Obj(value: Map[String, JsonNode]) extends JsonNode {
+    // keys are escaped the same way as string values so that unusual field names still yield valid JSON
+    override def render: String =
+      value.map((k, v) => s"\"${StringEscapeUtils.escapeJson(k)}\":${v.render}").mkString("{", ",", "}")
+  }
+
+  val `null`: JsonNode = Null
+
+  /** Builds an object node from one or more key/value pairs. */
+  def obj(head: (String, JsonNode), tail: (String, JsonNode)*): Obj =
+    Obj((head +: tail).to(Map))
+}