Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,3 @@ fquery with Django and get easy access to graphql functionality
This project is made available under the Apache License, version 2.0.

See [LICENSE.txt](license.txt) for details.

155 changes: 155 additions & 0 deletions fquery/cypher_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright (c) Arun Sharma, 2025
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import operator

from .visitor import Visitor

# inspired from pandas.core.computation.ops
_cmp_ops_syms = (">", "<", ">=", "<=", "==", "!=")
_cmp_ops_funcs = (
operator.gt,
operator.lt,
operator.ge,
operator.le,
operator.eq,
operator.ne,
)
_cmp_ops_dict = dict(zip(_cmp_ops_syms, _cmp_ops_funcs))


class CypherBuilderVisitor(Visitor):
def __init__(self, id1s):
self.cypher = None
self.match_parts = []
self.current_node = "u"
self.where_clauses = []
self.return_clause = ""
self.order_by_clause = ""
self.limit_clause = ""
self.visited = set()
self.node_counter = 0
self.root_label = None

@staticmethod
def table_from_query(query):
query_name = query.__class__.__name__.lower()
# UserQuery -> user -> User
query_name = query_name.split("query")[0]
return query_name.capitalize()

def _get_next_node_var(self):
self.node_counter += 1
return f"n{self.node_counter}"

async def visit_leaf(self, query):
if not self.root_label:
self.root_label = self.table_from_query(query)
self.match_parts = [f"({self.current_node}:{self.root_label})"]

if query in self.visited:
# Prevent infinite recursion
return
else:
self.visited.add(query)
for q in query.edges:
await self.visit(q)

async def visit_project(self, query):
await self.visit(query.child)
proj = ", ".join(
[
f"{self.current_node}.{x}" if x != ":id" else f"{self.current_node}.id"
for x in query.projector
]
)
self.return_clause = f"RETURN {proj}"

async def visit_take(self, query):
await self.visit(query.child)
self.limit_clause = f"LIMIT {query._count}"

async def visit_where(self, query):
await self.visit(query.child)
# TODO: more general lazy expression evaluator
left, op, right = query._expr.value.split()
right = ast.literal_eval(right)
table, field = left.split(".") if "." in left else (self.cypher, left)
self.where_clauses.append(f"{self.current_node}.{field} {op} {right}")

async def visit_order_by(self, query):
await self.visit(query.child)
key = query._expr.value
table, field = key.split(".") if "." in key else (self.cypher, key)
self.order_by_clause = f"ORDER BY {self.current_node}.{field}"

async def visit_edge(self, query):
# Ensure we have the root label and initial match part
if not self.root_label:
# Find the root query
root_query = query
while hasattr(root_query, "child") and root_query.child:
root_query = root_query.child
if hasattr(root_query, "__class__"):
self.root_label = self.table_from_query(root_query)
self.match_parts = [f"({self.current_node}:{self.root_label})"]

edge_name = query.edge_name

# Check if this is part of a multi-hop pattern of the same edge type
# Look ahead to see if the child has another edge of the same type
has_same_edge_child = (
hasattr(query, "child")
and hasattr(query.child, "OP")
and query.child.OP.name == "EDGE"
and query.child.edge_name == edge_name
)

if has_same_edge_child:
# Count the total number of consecutive edges of the same type
hops = 1 # current edge
current_query = query.child
while (
hasattr(current_query, "OP")
and current_query.OP.name == "EDGE"
and current_query.edge_name == edge_name
):
hops += 1
current_query = current_query.child

# This is the start of a multi-hop pattern (e.g., friend-of-friend, etc.)
self.match_parts = [
f"(a:{self.root_label})",
f"[e:{edge_name.upper()}*{hops}..{hops}]",
f"(b:{self.root_label})",
]
self.current_node = "b"
# Skip visiting the intermediate edges and visit the query after the chain
await self.visit(current_query)
else:
# Regular edge traversal
child_label = self.table_from_query(query._unbound)
next_node = self._get_next_node_var()
relationship = f"[:{edge_name.upper()}]"
self.match_parts.append(f"{relationship}->({next_node}:{child_label})")
self.current_node = next_node
await self.visit(query.child)

async def visit_union(self, query):
# UNION in Cypher - this is complex, would need to handle multiple queries
# For now, just visit the child
await self.visit(query.child)

async def visit_count(self, query):
await self.visit(query.child)
self.return_clause = "RETURN count(*)"

async def visit_nest(self, query):
# Nesting - not directly supported in Cypher
await self.visit(query.child)

async def visit_let(self, query):
# LET - for renaming, could be handled with AS in RETURN
await self.visit(query.child)
17 changes: 17 additions & 0 deletions fquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Dict, List, Optional, Tuple, Type, Union

from .async_utils import wait_for
from .cypher_builder import CypherBuilderVisitor
from .execute import AbstractSyntaxTreeVisitor
from .malloy_builder import MalloyBuilderVisitor
from .polars_builder import PolarsBuilderVisitor
Expand Down Expand Up @@ -257,6 +258,22 @@ def to_malloy(self) -> str:
wait_for(visitor.visit(self))
return visitor.malloy

def to_cypher(self) -> str:
visitor = CypherBuilderVisitor([])
wait_for(visitor.visit(self))
# Build the final query
match_pattern = "MATCH " + "-".join(visitor.match_parts)
qstr = match_pattern
if visitor.where_clauses:
qstr += "\nWHERE " + " AND ".join(visitor.where_clauses)
if visitor.return_clause:
qstr += "\n" + visitor.return_clause
if visitor.order_by_clause:
qstr += "\n" + visitor.order_by_clause
if visitor.limit_clause:
qstr += "\n" + visitor.limit_clause
return qstr

def to_polars(self) -> Tree:
visitor = PolarsBuilderVisitor([])
wait_for(visitor.visit(self))
Expand Down
81 changes: 81 additions & 0 deletions tests/test_cypher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import random
import textwrap
import unittest

from .mock_user import UserQuery


class CypherTests(unittest.TestCase):
def setUp(self):
random.seed(100)
self.maxDiff = None

def test_project(self):
cypher_q = (
UserQuery(range(1, 10))
.project([":id", "name"])
.where(ast.Expr("user.age >= 16"))
.order_by(ast.Expr("user.age"))
.take(3)
.to_cypher()
)
expected = textwrap.dedent(
"""\
MATCH (u:User)
WHERE u.age >= 16
RETURN u.id, u.name
ORDER BY u.age
LIMIT 3"""
)
self.assertEqual(expected, cypher_q)

def test_sync_edge_project(self):
cypher_q = (
UserQuery(range(1, 5))
.edge("friends")
.project(["name", ":id"])
.take(3)
.to_cypher()
)
expected = textwrap.dedent(
"""\
MATCH (u:User)-[:FRIENDS]->(n1:User)
RETURN n1.name, n1.id
LIMIT 3"""
)
self.assertEqual(expected, cypher_q)

def test_sync_two_hop_project(self):
cypher_q = (
UserQuery([1])
.edge("friends")
.edge("friends")
.project(["name", ":id"])
.take(3)
.to_cypher()
)
expected = textwrap.dedent(
"""\
MATCH (a:User)-[e:FRIENDS*2..2]-(b:User)
RETURN b.name, b.id
LIMIT 3"""
)
self.assertEqual(expected, cypher_q)

def test_count(self):
cypher_q = UserQuery(range(1, 10)).count().to_cypher()
expected = textwrap.dedent(
"""\
MATCH (u:User)
RETURN count(*)"""
)
self.assertEqual(expected, cypher_q)


if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion tests/test_data/test_data_edge_count.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,3 @@
]
}
]

1 change: 0 additions & 1 deletion tests/test_data/test_data_edge_let.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,3 @@
]
}
]