Skip to content

Commit 198dbdb

Browse files
committed
Function to exclude edges with input-row in table.
1 parent 8f64000 commit 198dbdb

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44
[project]
55
name = "fk-graph"
66
packages = ["src/fk_graph"]
7-
version = "0.0.11"
7+
version = "0.0.12"
88
authors = [
99
{ name="Andrew Curtis", email="fk.graph@fastmail.com" },
1010
{ name="John C Thomas" },

src/fk_graph/edge_excluders.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Edge Excluders.
2+
3+
Functions to pass in as an `exclude_edge` argument to `get_graph`.
4+
5+
The functions in this module that do not take any arguments can be passed in
6+
directly to `get_graph`. Those that do take arguments should be called with the
7+
required arguments, and the returned function passed to `get_graph`.
8+
"""
9+
10+
def input_row_is_in_tables(table_names):
11+
"""Checks whether input row-node is in list of tables.
12+
13+
This is typically used in cases where one wishes to include rows from a
14+
table as nodes in the graph, but *not* include the subsequent rows which
15+
can only be reached via these nodes.
16+
For example, if there is a `user` table with f-k relation to a `country`
17+
table, one might want to include the country that a user is from, but not
18+
also include all the other users that are from that country.
19+
20+
Args:
21+
table_names: A list of table names.
22+
23+
Returns:
24+
A function of the form (input_row, output_row) -> bool. The function will
25+
return `True` if and only if `input_row` is from a table with one of the
26+
names in `table_name`.
27+
"""
28+
def f(input_row, output_row):
29+
return _get_table_name_from_row(input_row) in table_names
30+
31+
return f
32+
33+
def _get_table_name_from_row(row):
34+
return row.__table__.name

tests/test_edge_excluders.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from unittest import TestCase
2+
3+
from sqlalchemy import create_engine
4+
from sqlalchemy.orm import (
5+
DeclarativeBase,
6+
Mapped,
7+
mapped_column,
8+
Session,
9+
)
10+
11+
from fk_graph.edge_excluders import input_row_is_in_tables
12+
13+
class TestInputRowIsInTable(TestCase):
14+
15+
def setUp(self):
16+
class Base(DeclarativeBase):
17+
pass
18+
19+
class TableA(Base):
20+
__tablename__ = "table_a"
21+
id: Mapped[int] = mapped_column(primary_key=True)
22+
23+
class TableB(Base):
24+
__tablename__ = "table_b"
25+
id: Mapped[int] = mapped_column(primary_key=True)
26+
27+
self.engine = create_engine("sqlite+pysqlite:///:memory:")
28+
Base.metadata.create_all(self.engine)
29+
self.TableA, self.TableB = TableA, TableB
30+
31+
def test_function_returns_true_if_input_row_from_table(self):
32+
with Session(self.engine) as session:
33+
table_a_row = self.TableA(id=1)
34+
table_b_row = self.TableB(id=1)
35+
session.add_all([table_a_row, table_b_row])
36+
session.commit()
37+
38+
with Session(self.engine) as session:
39+
table_a_row = session.get_one(self.TableA, 1)
40+
table_b_row = session.get_one(self.TableB, 1)
41+
42+
self.assertTrue(
43+
input_row_is_in_tables(["table_a"])(table_a_row, table_b_row)
44+
)
45+
46+
def test_function_returns_false_if_input_row_not_from_table(self):
47+
with Session(self.engine) as session:
48+
table_a_row = self.TableA(id=1)
49+
table_b_row = self.TableB(id=1)
50+
session.add_all([table_a_row, table_b_row])
51+
session.commit()
52+
53+
with Session(self.engine) as session:
54+
table_a_row = session.get_one(self.TableA, 1)
55+
table_b_row = session.get_one(self.TableB, 1)
56+
57+
self.assertFalse(
58+
input_row_is_in_tables(["table_b"])(table_a_row, table_b_row)
59+
)

0 commit comments

Comments
 (0)