From 938ccfa9d329a87fdceed2ef67fad7d6258bddf7 Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Mon, 24 Jan 2022 02:12:49 +1000 Subject: [PATCH 1/7] Add first tests for sql and add partial implementation of sql module https://www.notion.so/SQL-494243006ef04f68b803fff9104e5597 --- EGE/SQL/Aggregate.py | 29 +++++++ EGE/SQL/RandomTable.py | 69 +++++++++++++++ EGE/SQL/Table.py | 193 +++++++++++++++++++++++++++++++++++++++++ EGE/SQL/Utils.py | 16 ++++ EGE/SQL/__init__.py | 0 EGE/Utils.py | 14 ++- EGE/test/test_SQL.py | 72 +++++++++++++++ 7 files changed, 389 insertions(+), 4 deletions(-) create mode 100644 EGE/SQL/Aggregate.py create mode 100644 EGE/SQL/RandomTable.py create mode 100644 EGE/SQL/Table.py create mode 100644 EGE/SQL/Utils.py create mode 100644 EGE/SQL/__init__.py create mode 100644 EGE/test/test_SQL.py diff --git a/EGE/SQL/Aggregate.py b/EGE/SQL/Aggregate.py new file mode 100644 index 0000000..5a61dab --- /dev/null +++ b/EGE/SQL/Aggregate.py @@ -0,0 +1,29 @@ +class Aggregate: + @staticmethod + def call(self): # abstract + pass + +class Count(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Sum(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Avg(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Min(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Max(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() \ No newline at end of file diff --git a/EGE/SQL/RandomTable.py b/EGE/SQL/RandomTable.py new file mode 100644 index 0000000..5656106 --- /dev/null +++ b/EGE/SQL/RandomTable.py @@ -0,0 +1,69 @@ +from EGE.Random import Random + +class BaseTable: + columns: list + def __init__(self, rnd: Random, col_count, row_count, name): + self.rnd = rnd + self.columns = self.get_columns() + self.fields = [ + self.columns[0], + rnd.pick_n( + col_count - 1, + self.columns[1:len(self.columns)] + ) + ] + row_sources = [ row for row in self.get_rows_array() + if len(row) >= row_count ] + self.rows = rnd.pick_n(row_count, row_sources) + + + def get_columns(self) -> list: + return [] + + def get_rows_array(self): + pass + +class Products(BaseTable): + pass + +class Jobs(BaseTable): + pass + +class SalesMonth(BaseTable): + pass + +class Cities(BaseTable): + pass + +class People(BaseTable): + pass + +class Subjects(BaseTable): + pass + +class Marks(BaseTable): + pass + +class ParticipantsMonth(BaseTable): + pass + +def ok_table(table: BaseTable, rows, cols): + t_cols = table.get_columns() + t_rows = table.get_rows_array() + return t_cols >= cols and any([len(t_rows) > rows]) + +def pick(rnd: Random, *args): + return rnd.pick([ + lambda *args2: Products(args2), + lambda *args2: Jobs(args2), + lambda *args2: SalesMonth(args2), + lambda *args2: Cities(args2), + lambda *args2: People(args2), + lambda *args2: Subjects(args2), + lambda *args2: Marks(args2), + lambda *args2: ParticipantsMonth(args2), + ])(args) + +def create_table(rnd: Random, rows: int, columns: int): + table = pick(rnd, rows, columns) + return table diff --git a/EGE/SQL/Table.py b/EGE/SQL/Table.py new file mode 100644 index 0000000..5f7a04f --- /dev/null +++ b/EGE/SQL/Table.py @@ -0,0 +1,193 @@ +from EGE.GenBase import EGEError +from EGE.Prog import CallFuncAggregate, make_expr +from EGE.Random import Random +from EGE.Utils import aggregate_function + +class Field: + def __init__(self, attr): + self.name: str = '' + self.name_alias: str = '' + if isinstance(attr, dict): + self.name = attr['name'] + self.name_alias = attr['name_alias'] + else: + self.name = attr + + def to_lang(self): + return self.name_alias + '.' + self.name if self.name_alias else self.name + + def __str__(self): + return self.name + + def __repr__(self): + if self.name_alias: + return f"Field({{'name': '{self.name}', 'name_alias': '{self.name_alias}'}})" + return f"Field('{self.name}')" + + +class Table: + def __init__(self, fields: list, **kwargs): + if not fields or fields is None: + raise EGEError('No fields') + + self.fields: list[Field] = [ self._make_field(i) for i in fields ] + self.name: str = kwargs['name'] if 'name' in kwargs.keys() else '' + self.field_index: dict = {} + self._update_field_index() + for i in self.fields: + i.table = self + self.data: list = [] + + def _make_field(self, field): + return field if isinstance(field, Field) else Field(field) + + def _update_field_index(self): + for i, key in enumerate(self.fields): + self.field_index[key.name] = i + + def name(self, name: str): + self.name = name + return name + + def fields(self): + return self.fields + + def assign_field_alias(self, alias: str): + raise NotImplemented() + + def insert_row(self, *fields): + if len(fields) != len(self.fields): + EGEError(f'Wrong column count {len(fields)} != {len(self.fields)}') + self.data.append(fields) + return self + + def insert_rows(self, *rows): + for fields in rows: + self.insert_row(*fields) + return self + + def insert_column(self): + raise NotImplemented() + + def print_row(self, row): + raise NotImplemented() + + def print(self): + raise NotImplemented() + + def count(self): + return len(self.data) + + def _row_hash(self, row, env=None): + if env is None: + env = {} + for f in self.fields: + env[f.name] = row[self.field_index[f.name]] + return env + + def _hash(self): + env = { '&columns': { str(f): self.column_array(str(f)) for f in self.fields }, + '&': aggregate_function(), + '&count': self.count() } + return env + + def select(self, fields, where=None, p=None): + if not isinstance(fields, list): + fields = [fields] + ref, aggr, group, having = 0, 0, 0, 0 + if isinstance(ref, dict): + ref = p[ref] + group = p[group] + having = p[having] + + aggr = list(filter(lambda x: isinstance(x, CallFuncAggregate), fields)) + k = 0 + args = [] + exprs = [ f'expr_{i}' for i in range(1, 4) ] + for f in fields: + # TODO this seems to be wrong. Check later + if not isinstance(f, Field) and hasattr(f, '__call__') and f.__name__ in exprs: + k += 1 + args.append(f'expr_{k}') + else: + args.append(f) + result = Table(args) + values = [ f if hasattr(f, '__call__') else make_expr(f) for f in fields ] + calc_row = lambda x: [ i.run(x) for i in values ] + + tab_where = self.where(where, ref) + if group: + raise NotImplemented() + else: + ans = [] + env = tab_where._hash() + for data in tab_where.data: + ans.append(calc_row(tab_where._row_hash(data, env))) + result.data = [ ans[0] ] if aggr else ans #TODO check correctness in case if aggr == True + + return result + + def group_by(self): + raise NotImplemented() + + def where(self, where=None, ref=None): + """where: Union[callable, None], ref: Union[bool, None]""" + if where is None: + return self + table = Table(self.fields) + table.data = [ data if ref else data[:] for data in self.data if where(self._row_hash(data)) ] + return table + + def count_where(self): + raise NotImplemented() + + def update(self): + raise NotImplemented() + + def delete(self): + raise NotImplemented() + + def natural_join(self): + raise NotImplemented() + + def inner_join(self): + raise NotImplemented() + + def inner_join_expr(self): + raise NotImplemented() + + def table_html(self): + raise NotImplemented() + + def fetch_val(self): + raise NotImplemented() + + #TODO how best to implement getting the random number generator in table methods + def random_row(self, rnd: Random): + return rnd.pick(self.data) + + def _field_index(self, field: str): + """field: Union[int, str]""" + if isinstance(field, int): + if 1 <= field <= len(self.fields): + return field - 1 + else: + raise EGEError(f"Unknown field {field}") + if field not in self.field_index.keys(): + raise EGEError(f"Unknown field {field}") + return self.field_index[field] + + def column_array(self, field): + """field: Union[int, str]""" + column_idx = self._field_index(field) + return [ data[column_idx] for data in self.data ] + + def column_hash(self, field: int | str): + """field: Union[int, str]""" + column_idx = self._field_index(field) + r = {} + for row in self.data: + if row[column_idx] not in r.keys(): + r[row[column_idx]] = 0 + r[row[column_idx]] += 1 + return r diff --git a/EGE/SQL/Utils.py b/EGE/SQL/Utils.py new file mode 100644 index 0000000..a7a8b0f --- /dev/null +++ b/EGE/SQL/Utils.py @@ -0,0 +1,16 @@ +from EGE.Random import Random +# from EGE.SQL.Table import Aggregate + +def create_table(rnd: Random, row: int, col: int, name: str): + product = None + return product + +# def aggregate_function(name=None): +# """name: Union[str, None]""" +# aggr = Aggregate.__subclasses__() +# if name is not None: +# for sub in aggr: +# if name == sub.__name__: +# return sub +# return [sub for sub in aggr] + diff --git a/EGE/SQL/__init__.py b/EGE/SQL/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/EGE/Utils.py b/EGE/Utils.py index 5eb3cf8..bc2d71b 100644 --- a/EGE/Utils.py +++ b/EGE/Utils.py @@ -1,8 +1,14 @@ import math - -def aggregate_function(name): - # TODO - return None +from EGE.SQL.Aggregate import Aggregate + +def aggregate_function(name:str = None): + '''name: Union[str, None]''' + aggr = Aggregate.__subclasses__() + if name is not None: + for sub in aggr: + if name == sub.__name__: + return sub + return [ sub for sub in aggr ] def last_key(d, key): while key in d[key]: diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py new file mode 100644 index 0000000..c86f463 --- /dev/null +++ b/EGE/test/test_SQL.py @@ -0,0 +1,72 @@ +import unittest + +from EGE.GenBase import EGEError +from EGE.Random import Random + +if __name__ == '__main__': + import sys + sys.path.append('..') + from SQL.Table import Table + # from SQL.Utils import Utils + # from SQL.RandomTable import RandomTable +else: + from ..SQL.Table import Table + # from ..SQL.Utils import Utils + # from ..SQL.RandomTable import RandomTable + +rnd = Random(2342134) + +def join_sp(*elements): + res = ' '.join([str(el) for el in elements]) + return res + +def pack_table(table: Table): + res = '|'.join([ join_sp(*[ str(f) for f in table.fields ]), *[ join_sp(*d) for d in table.data ] ]) + return res + +def pack_table_sorted(table: Table): + return '|'.join([ join_sp(table.fields) ] + [ join_sp(d) for d in sorted(table.data) ]) + +class test_SQL(unittest.TestCase): + def test_CreateTable(self): + eq = self.assertEqual + + with self.assertRaisesRegex(EGEError, 'fields', msg='no fields'): + t = Table(None) #test seems to be useless because python throw exception if positional argument is missed + t = Table(['a', 'b', 'c'], name='table') + eq(t.name, 'table', 'table name') + + def test_SelectInsertHash(self): + eq = self.assertEqual + tab = Table('id name'.split()) + + self.assertDictEqual(tab._row_hash([ 2, 3 ]), { 'id': 2, 'name': 3 }, 'row hash') + tab.insert_rows([ 1, 'aaa' ], [ 2, 'bbb' ]) + eq(pack_table(tab.select([ 'id', 'name' ])), 'id name|1 aaa|2 bbb', 'all fields') + tab.insert_row(3, 'ccc') + eq(pack_table(tab.select('id')), 'id|1|2|3', 'field 1') + + self.assertListEqual(tab.column_array('id'), [ 1, 2, 3 ], 'column_array') + self.assertListEqual(tab.column_array(1), [ 1, 2, 3 ], 'column_array by number') + with self.assertRaisesRegex(EGEError, 'zzz', msg='column_array none'): + tab.column_array('zzz') + with self.assertRaisesRegex(EGEError, '77', msg='column_array by number none'): + tab.column_array(77) + + self.assertDictEqual(tab.column_hash('name'), { 'aaa': 1, 'bbb': 1, 'ccc': 1 }, 'column_hash') + self.assertDictEqual(tab.column_hash(2), { 'aaa': 1, 'bbb': 1, 'ccc': 1 }, 'column_hash by number') + with self.assertRaisesRegex(EGEError, 'xxx', msg='column_hash none'): + tab.column_hash('xxx') + + eq(pack_table(tab.select('name')), 'name|aaa|bbb|ccc', 'field 2') + eq(pack_table(tab.select([ 'id', 'id' ])), 'id id|1 1|2 2|3 3', 'duplicate field') + #TODO exception used to be raised from Prog::Var::run, but it removed from there. Need to restore it or choose new place to raise + # with self.assertRaisesRegex(EGEError, 'zzz', msg='bad field'): + # tab.select('zzz') + + r = tab.random_row(rnd)[0] + self.assertTrue(r == 1 or r == 2 or r == 3, 'random_row') + + +if __name__ == '__main__': + unittest.main(verbosity=1) \ No newline at end of file From fb2b771c06dfc27ea3a88cd2a7fc81b8e42881f1 Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Thu, 27 Jan 2022 13:12:28 +1000 Subject: [PATCH 2/7] Update gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 55e63fc..6f10ed6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.pyc t.xhtml t.py +.idea +.venv From 776aa88d5ae53d18ea4feaac8aab21276861875d Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Thu, 27 Jan 2022 15:44:34 +1000 Subject: [PATCH 3/7] Add more tests and add raise exception when table does not have field requested by select https://www.notion.so/SQL-494243006ef04f68b803fff9104e5597 --- EGE/SQL/Table.py | 22 +++++++++++++++------- EGE/test/test_SQL.py | 44 +++++++++++++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/EGE/SQL/Table.py b/EGE/SQL/Table.py index 5f7a04f..9955a2c 100644 --- a/EGE/SQL/Table.py +++ b/EGE/SQL/Table.py @@ -1,5 +1,5 @@ from EGE.GenBase import EGEError -from EGE.Prog import CallFuncAggregate, make_expr +from EGE.Prog import CallFuncAggregate, make_expr, Op from EGE.Random import Random from EGE.Utils import aggregate_function @@ -24,6 +24,14 @@ def __repr__(self): return f"Field({{'name': '{self.name}', 'name_alias': '{self.name_alias}'}})" return f"Field('{self.name}')" + def __eq__(self, other): + if isinstance(other, str): + return other == self.name or other == self.name_alias + elif isinstance(other, Field): + return other.name == self.name and other.name_alias == self.name_alias + raise EGEError("Imvalid type to compare with Field!") + + class Table: def __init__(self, fields: list, **kwargs): @@ -49,16 +57,13 @@ def name(self, name: str): self.name = name return name - def fields(self): - return self.fields - def assign_field_alias(self, alias: str): raise NotImplemented() def insert_row(self, *fields): if len(fields) != len(self.fields): EGEError(f'Wrong column count {len(fields)} != {len(self.fields)}') - self.data.append(fields) + self.data.append(list(fields)) return self def insert_rows(self, *rows): @@ -93,7 +98,10 @@ def _hash(self): def select(self, fields, where=None, p=None): if not isinstance(fields, list): - fields = [fields] + fields = [ fields ] + for f in fields: + if not isinstance(f, CallFuncAggregate) and f not in self.fields: + raise EGEError(f"Table '{self.name}' does not contain field '{f}'!") ref, aggr, group, having = 0, 0, 0, 0 if isinstance(ref, dict): ref = p[ref] @@ -182,7 +190,7 @@ def column_array(self, field): column_idx = self._field_index(field) return [ data[column_idx] for data in self.data ] - def column_hash(self, field: int | str): + def column_hash(self, field): """field: Union[int, str]""" column_idx = self._field_index(field) r = {} diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py index c86f463..6a4f395 100644 --- a/EGE/test/test_SQL.py +++ b/EGE/test/test_SQL.py @@ -28,7 +28,7 @@ def pack_table_sorted(table: Table): return '|'.join([ join_sp(table.fields) ] + [ join_sp(d) for d in sorted(table.data) ]) class test_SQL(unittest.TestCase): - def test_CreateTable(self): + def test_create_table(self): eq = self.assertEqual with self.assertRaisesRegex(EGEError, 'fields', msg='no fields'): @@ -36,33 +36,51 @@ def test_CreateTable(self): t = Table(['a', 'b', 'c'], name='table') eq(t.name, 'table', 'table name') - def test_SelectInsertHash(self): + def test_insert_row_copies(self): + eq = self.assertEqual + + tab = Table(['f'], name='table') + r = 1 + tab.insert_row(r) + r = 2 + eq('f|1', pack_table(tab), 'insert_row copies') + + def test_insert_rows_copies(self): + eq = self.assertEqual + + tab = Table(['f'], name='table') + r = [ 1 ] + tab.insert_rows(r, r) + r[0] = 2 + eq('f|1|1', pack_table(tab), 'insert_row copies') + + def test_select_insert_hash(self): eq = self.assertEqual tab = Table('id name'.split()) self.assertDictEqual(tab._row_hash([ 2, 3 ]), { 'id': 2, 'name': 3 }, 'row hash') tab.insert_rows([ 1, 'aaa' ], [ 2, 'bbb' ]) - eq(pack_table(tab.select([ 'id', 'name' ])), 'id name|1 aaa|2 bbb', 'all fields') + eq('id name|1 aaa|2 bbb', pack_table(tab.select([ 'id', 'name' ])), 'all fields') tab.insert_row(3, 'ccc') - eq(pack_table(tab.select('id')), 'id|1|2|3', 'field 1') + eq('id|1|2|3', pack_table(tab.select('id')), 'field 1') - self.assertListEqual(tab.column_array('id'), [ 1, 2, 3 ], 'column_array') - self.assertListEqual(tab.column_array(1), [ 1, 2, 3 ], 'column_array by number') + self.assertListEqual([ 1, 2, 3 ], tab.column_array('id'), 'column_array') + self.assertListEqual([ 1, 2, 3 ], tab.column_array(1), 'column_array by number') with self.assertRaisesRegex(EGEError, 'zzz', msg='column_array none'): tab.column_array('zzz') with self.assertRaisesRegex(EGEError, '77', msg='column_array by number none'): tab.column_array(77) - self.assertDictEqual(tab.column_hash('name'), { 'aaa': 1, 'bbb': 1, 'ccc': 1 }, 'column_hash') - self.assertDictEqual(tab.column_hash(2), { 'aaa': 1, 'bbb': 1, 'ccc': 1 }, 'column_hash by number') + self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash('name'), 'column_hash') + self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash(2), 'column_hash by number') with self.assertRaisesRegex(EGEError, 'xxx', msg='column_hash none'): tab.column_hash('xxx') - eq(pack_table(tab.select('name')), 'name|aaa|bbb|ccc', 'field 2') - eq(pack_table(tab.select([ 'id', 'id' ])), 'id id|1 1|2 2|3 3', 'duplicate field') - #TODO exception used to be raised from Prog::Var::run, but it removed from there. Need to restore it or choose new place to raise - # with self.assertRaisesRegex(EGEError, 'zzz', msg='bad field'): - # tab.select('zzz') + eq('name|aaa|bbb|ccc', pack_table(tab.select('name')), 'field 2') + eq('id id|1 1|2 2|3 3', pack_table(tab.select([ 'id', 'id' ])), 'duplicate field') + # TODO exception used to be raised from Prog::Var::run, but it removed from there. Need to restore it or choose new place to raise + with self.assertRaisesRegex(EGEError, 'zzz', msg='bad field'): + tab.select('zzz') r = tab.random_row(rnd)[0] self.assertTrue(r == 1 or r == 2 or r == 3, 'random_row') From 1ad4eddef36c9c270c4683849d0c9abca58e9d31 Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Thu, 27 Jan 2022 15:57:40 +1000 Subject: [PATCH 4/7] Add implementation of count_where and test for it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://www.notion.so/SQL-494243006ef04f68b803fff9104e5597 — also allow create table with empty fields --- EGE/SQL/Table.py | 13 ++++++++----- EGE/test/test_SQL.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/EGE/SQL/Table.py b/EGE/SQL/Table.py index 9955a2c..fd4fca5 100644 --- a/EGE/SQL/Table.py +++ b/EGE/SQL/Table.py @@ -35,7 +35,7 @@ def __eq__(self, other): class Table: def __init__(self, fields: list, **kwargs): - if not fields or fields is None: + if fields is None: raise EGEError('No fields') self.fields: list[Field] = [ self._make_field(i) for i in fields ] @@ -138,16 +138,19 @@ def select(self, fields, where=None, p=None): def group_by(self): raise NotImplemented() - def where(self, where=None, ref=None): + def where(self, where:Op=None, ref=None): """where: Union[callable, None], ref: Union[bool, None]""" if where is None: return self table = Table(self.fields) - table.data = [ data if ref else data[:] for data in self.data if where(self._row_hash(data)) ] + table.data = [ data if ref else data[:] for data in self.data if where.run(self._row_hash(data)) ] return table - def count_where(self): - raise NotImplemented() + def count_where(self, where: Op = None): + if where is None: + return self.count() + res = list(filter(lambda x: where.run(self._row_hash(x)), self.data)) + return len(res) def update(self): raise NotImplemented() diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py index 6a4f395..0b7afa0 100644 --- a/EGE/test/test_SQL.py +++ b/EGE/test/test_SQL.py @@ -2,6 +2,7 @@ from EGE.GenBase import EGEError from EGE.Random import Random +from EGE.Prog import make_expr if __name__ == '__main__': import sys @@ -85,6 +86,20 @@ def test_select_insert_hash(self): r = tab.random_row(rnd)[0] self.assertTrue(r == 1 or r == 2 or r == 3, 'random_row') + def test_select_where(self): + eq = self.assertEqual + + tab = Table('id name city'.split()) + tab.insert_rows([ 1, 'aaa', 3 ], [ 2, 'bbb', 2 ], [ 3, 'ccc', 1 ], [ 4, 'bbn', 2 ]) + e = make_expr([ '==', 'city', 2 ]) + eq('id name city|2 bbb 2|4 bbn 2', pack_table(tab.where(e)), 'where city == 2') + eq(2, tab.count_where(e), 'count_where city == 2') + eq('id name|2 bbb|4 bbn', pack_table(tab.select(['id', 'name'], e)), 'select id, name where city == 2') + eq('||', pack_table(tab.select([], e)), 'select where city == 2') + eq(4, tab.count(), 'count') + eq(0, tab.where(make_expr(0)).count(), 'where false') + eq(0, tab.count_where(make_expr(0)), 'count_where false') + if __name__ == '__main__': unittest.main(verbosity=1) \ No newline at end of file From 0f8368c92a5cbb36e24233644ad283d344f5eabe Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Fri, 28 Jan 2022 01:50:49 +1000 Subject: [PATCH 5/7] Add test for where operation by ref https://www.notion.so/SQL-494243006ef04f68b803fff9104e5597 --- EGE/SQL/Table.py | 2 +- EGE/test/test_SQL.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/EGE/SQL/Table.py b/EGE/SQL/Table.py index fd4fca5..a6edd5e 100644 --- a/EGE/SQL/Table.py +++ b/EGE/SQL/Table.py @@ -138,7 +138,7 @@ def select(self, fields, where=None, p=None): def group_by(self): raise NotImplemented() - def where(self, where:Op=None, ref=None): + def where(self, where: Op = None, ref: bool = None): """where: Union[callable, None], ref: Union[bool, None]""" if where is None: return self diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py index 0b7afa0..bfac0bb 100644 --- a/EGE/test/test_SQL.py +++ b/EGE/test/test_SQL.py @@ -3,6 +3,7 @@ from EGE.GenBase import EGEError from EGE.Random import Random from EGE.Prog import make_expr +from EGE.Utils import nrange if __name__ == '__main__': import sys @@ -100,6 +101,18 @@ def test_select_where(self): eq(0, tab.where(make_expr(0)).count(), 'where false') eq(0, tab.count_where(make_expr(0)), 'count_where false') + def test_where_copy_ref(self): + eq = self.assertEqual + + tab = Table([ 'id' ]) + tab.insert_rows(*[ [i] for i in range(1, 6) ]) + e = make_expr([ '==', 'id', '3' ]) + w2 = tab.where(e) + w2.data[0][0] = 9 + eq('id|1|2|3|4|5', pack_table(tab), 'where copy') + w1 = tab.where(e, True) + w1.data[0][0] = 9 + eq('id|1|2|9|4|5', pack_table(tab), 'where ref') if __name__ == '__main__': unittest.main(verbosity=1) \ No newline at end of file From 1f4773ca4afc3f36d6d91a725f77ffb9e1de13c8 Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Fri, 28 Jan 2022 02:20:12 +1000 Subject: [PATCH 6/7] Add using of subTest to improve the look of tests run results https://www.notion.so/SQL-494243006ef04f68b803fff9104e5597 --- EGE/test/test_SQL.py | 110 +++++++++++++++++++++++++++++++------------ 1 file changed, 79 insertions(+), 31 deletions(-) diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py index bfac0bb..75d589d 100644 --- a/EGE/test/test_SQL.py +++ b/EGE/test/test_SQL.py @@ -29,6 +29,7 @@ def pack_table(table: Table): def pack_table_sorted(table: Table): return '|'.join([ join_sp(table.fields) ] + [ join_sp(d) for d in sorted(table.data) ]) +#TODO reorganize tests to multiple test cases to simplify structure of complex test_methods class test_SQL(unittest.TestCase): def test_create_table(self): eq = self.assertEqual @@ -60,32 +61,58 @@ def test_select_insert_hash(self): eq = self.assertEqual tab = Table('id name'.split()) - self.assertDictEqual(tab._row_hash([ 2, 3 ]), { 'id': 2, 'name': 3 }, 'row hash') + with self.subTest(msg='test _row_hash'): + self.assertDictEqual(tab._row_hash([ 2, 3 ]), { 'id': 2, 'name': 3 }) + tab.insert_rows([ 1, 'aaa' ], [ 2, 'bbb' ]) - eq('id name|1 aaa|2 bbb', pack_table(tab.select([ 'id', 'name' ])), 'all fields') + with self.subTest(msg='select all fields'): + eq('id name|1 aaa|2 bbb', pack_table(tab.select([ 'id', 'name' ]))) + tab.insert_row(3, 'ccc') - eq('id|1|2|3', pack_table(tab.select('id')), 'field 1') - - self.assertListEqual([ 1, 2, 3 ], tab.column_array('id'), 'column_array') - self.assertListEqual([ 1, 2, 3 ], tab.column_array(1), 'column_array by number') - with self.assertRaisesRegex(EGEError, 'zzz', msg='column_array none'): - tab.column_array('zzz') - with self.assertRaisesRegex(EGEError, '77', msg='column_array by number none'): - tab.column_array(77) - - self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash('name'), 'column_hash') - self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash(2), 'column_hash by number') - with self.assertRaisesRegex(EGEError, 'xxx', msg='column_hash none'): - tab.column_hash('xxx') - - eq('name|aaa|bbb|ccc', pack_table(tab.select('name')), 'field 2') - eq('id id|1 1|2 2|3 3', pack_table(tab.select([ 'id', 'id' ])), 'duplicate field') - # TODO exception used to be raised from Prog::Var::run, but it removed from there. Need to restore it or choose new place to raise - with self.assertRaisesRegex(EGEError, 'zzz', msg='bad field'): - tab.select('zzz') + with self.subTest(msg='test select field id'): + eq('id|1|2|3', pack_table(tab.select('id'))) + + with self.subTest(msg='test column_array (str arg)'): + self.assertListEqual([ 1, 2, 3 ], tab.column_array('id')) + + with self.subTest(msg='test column_array (number arg)'): + self.assertListEqual([ 1, 2, 3 ], tab.column_array(1)) + + with self.subTest(msg='test column_array none (str arg)'): + with self.assertRaisesRegex(EGEError, 'zzz', msg='column_array none'): + tab.column_array('zzz') + + with self.subTest(msg='test column_array none (number arg)'): + with self.assertRaisesRegex(EGEError, '77', msg='column_array by number none'): + tab.column_array(77) + + with self.subTest(msg='test column_hash (str arg)'): + self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash('name')) + + with self.subTest(msg='test column_hash (number arg)'): + self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash(2)) + + with self.subTest(msg='test column_hash none (str argument)'): + with self.assertRaisesRegex(EGEError, 'xxx'): + tab.column_hash('xxx') + + with self.subTest(msg='test column_hash none (number argument)'): + with self.assertRaisesRegex(EGEError, '42'): + tab.column_hash(42) + + with self.subTest(msg='test select field name'): + eq('name|aaa|bbb|ccc', pack_table(tab.select('name'))) + + with self.subTest(msg='test select two same fields'): + eq('id id|1 1|2 2|3 3', pack_table(tab.select([ 'id', 'id' ]))) + + with self.subTest(msg='test select non existing field'): + with self.assertRaisesRegex(EGEError, 'zzz'): + tab.select('zzz') r = tab.random_row(rnd)[0] - self.assertTrue(r == 1 or r == 2 or r == 3, 'random_row') + with self.subTest(msg='test random row'): + self.assertTrue(r == 1 or r == 2 or r == 3) def test_select_where(self): eq = self.assertEqual @@ -93,13 +120,27 @@ def test_select_where(self): tab = Table('id name city'.split()) tab.insert_rows([ 1, 'aaa', 3 ], [ 2, 'bbb', 2 ], [ 3, 'ccc', 1 ], [ 4, 'bbn', 2 ]) e = make_expr([ '==', 'city', 2 ]) - eq('id name city|2 bbb 2|4 bbn 2', pack_table(tab.where(e)), 'where city == 2') - eq(2, tab.count_where(e), 'count_where city == 2') - eq('id name|2 bbb|4 bbn', pack_table(tab.select(['id', 'name'], e)), 'select id, name where city == 2') - eq('||', pack_table(tab.select([], e)), 'select where city == 2') - eq(4, tab.count(), 'count') - eq(0, tab.where(make_expr(0)).count(), 'where false') - eq(0, tab.count_where(make_expr(0)), 'count_where false') + + with self.subTest(msg='test where city == 2'): + eq('id name city|2 bbb 2|4 bbn 2', pack_table(tab.where(e))) + + with self.subTest(msg='test count_where city == 2'): + eq(2, tab.count_where(e)) + + with self.subTest(msg='select id, name where city == 2'): + eq('id name|2 bbb|4 bbn', pack_table(tab.select(['id', 'name'], e))) + + with self.subTest(msg='test select where city == 2'): + eq('||', pack_table(tab.select([], e))) + + with self.subTest(msg='test count rows'): + eq(4, tab.count()) + + with self.subTest(msg='test where false'): + eq(0, tab.where(make_expr(0)).count()) + + with self.subTest(msg='test count_where false'): + eq(0, tab.count_where(make_expr(0))) def test_where_copy_ref(self): eq = self.assertEqual @@ -107,12 +148,19 @@ def test_where_copy_ref(self): tab = Table([ 'id' ]) tab.insert_rows(*[ [i] for i in range(1, 6) ]) e = make_expr([ '==', 'id', '3' ]) + w2 = tab.where(e) w2.data[0][0] = 9 - eq('id|1|2|3|4|5', pack_table(tab), 'where copy') + with self.subTest(msg='test where copy'): + eq('id|1|2|3|4|5', pack_table(tab)) + w1 = tab.where(e, True) w1.data[0][0] = 9 - eq('id|1|2|9|4|5', pack_table(tab), 'where ref') + with self.subTest(msg='test where ref'): + eq('id|1|2|9|4|5', pack_table(tab)) + + # def test_update_var(self): + if __name__ == '__main__': unittest.main(verbosity=1) \ No newline at end of file From 0fc584b8fcd4f081c445f3621eb905660e63349a Mon Sep 17 00:00:00 2001 From: Vladimir Glushkov Date: Fri, 28 Jan 2022 21:18:23 +1000 Subject: [PATCH 7/7] Add Update query, update method to Table and tests https://www.notion.so/SQL-494243006ef04f68b803fff9104e5597 --- EGE/SQL/Queries.py | 81 ++++++++++++++++++++++++++++++++++++++++++++ EGE/SQL/Table.py | 16 +++++---- EGE/test/test_SQL.py | 33 ++++++++++++++++-- 3 files changed, 120 insertions(+), 10 deletions(-) create mode 100644 EGE/SQL/Queries.py diff --git a/EGE/SQL/Queries.py b/EGE/SQL/Queries.py new file mode 100644 index 0000000..16dffb5 --- /dev/null +++ b/EGE/SQL/Queries.py @@ -0,0 +1,81 @@ +from abc import ABC, abstractmethod + +import EGE.Html as html +from EGE.GenBase import EGEError +from EGE.SQL.Table import Table, Field +from EGE.Prog import SynElement, Block + +class Query(ABC): + def __init__(self, table, *, + where: SynElement = None, + having: SynElement = None, + group_by: list[SynElement] = None): + """table: Union[Table, str]""" + if table is None: + raise EGEError('Table is none!') + + self.table = table if isinstance(table, Table) else None + self.table_name = table.name if isinstance(table, Table) else table + + self.where: SynElement = where + self.having: SynElement = having + self.group_by: list[SynElement] = group_by + + @abstractmethod + def text(self, opts: dict): + """opts: dict[str, Any]""" + pass + + @abstractmethod + def run(self): + pass + + def text_html(self): + return self.text({ 'html': 1 }) + + def text_html_tt(self): + return html.tag('tt', self.text({ 'html': 1 })) + + def _field_list_sql(self, fields: list, opts: dict): + """ + fields: list[SynElement | str] + opts: dict[str, Any] + """ + return ', '.join([ f.to_lang_named('SQL', opts) + if issubclass(type(f), SynElement) else f + for f in fields ]) + + def where_sql(self, opts: dict): + return f"WHERE {self.where.to_lang_named('SQL', opts)}" if self.where else '' + + def having_sql(self, opts: dict): + return f"HAVING {self.having.to_lang_named('SQL', opts)}" if self.having else '' + + def group_by_sql(self, opts: dict): + return f"GROUP BY {self._field_list_sql(self.group_by, opts)}" if self.group_by else '' + + def _maybe_run(self): + with getattr(self, 'run', None) as run: + if callable(run): + return run() + return self + +class Update(Query): + def __init__(self, table, assigns: Block, *, where: SynElement = None): + if assigns is None: + raise EGEError('Assigns is none!') + super(Update, self).__init__(table, where=where) + self.assigns: Block = assigns + + def run(self): + self.table.update(self.assigns, self.where) + # TODO check if need return table + + def text(self, opts: dict): + """opts: dict[str, Any]""" + assigns = self.assigns.to_lang_named('SQL') + where_sql = self.where_sql(opts) + if where_sql: + return f"UPDATE {self.table_name} SET {assigns} {where_sql}" + else: + return f"UPDATE {self.table_name} SET {assigns}" diff --git a/EGE/SQL/Table.py b/EGE/SQL/Table.py index a6edd5e..e67ce52 100644 --- a/EGE/SQL/Table.py +++ b/EGE/SQL/Table.py @@ -1,5 +1,5 @@ from EGE.GenBase import EGEError -from EGE.Prog import CallFuncAggregate, make_expr, Op +from EGE.Prog import CallFuncAggregate, make_expr, Op, SynElement, Block from EGE.Random import Random from EGE.Utils import aggregate_function @@ -53,10 +53,6 @@ def _update_field_index(self): for i, key in enumerate(self.fields): self.field_index[key.name] = i - def name(self, name: str): - self.name = name - return name - def assign_field_alias(self, alias: str): raise NotImplemented() @@ -152,8 +148,14 @@ def count_where(self, where: Op = None): res = list(filter(lambda x: where.run(self._row_hash(x)), self.data)) return len(res) - def update(self): - raise NotImplemented() + def update(self, assigns: Block, where: Op = None): + data = self.data if where is None else self.where(where, True) + for row in data: + row_hash = self._row_hash(row) + assigns.run(row_hash) + for f in self.fields: + row[self._field_index(f.name)] = row_hash[f.name] + return self def delete(self): raise NotImplemented() diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py index 75d589d..689b8f3 100644 --- a/EGE/test/test_SQL.py +++ b/EGE/test/test_SQL.py @@ -2,17 +2,19 @@ from EGE.GenBase import EGEError from EGE.Random import Random -from EGE.Prog import make_expr +from EGE.Prog import make_expr, make_block from EGE.Utils import nrange if __name__ == '__main__': import sys sys.path.append('..') from SQL.Table import Table + from SQL.Queries import * # from SQL.Utils import Utils # from SQL.RandomTable import RandomTable else: - from ..SQL.Table import Table + from EGE.SQL.Table import Table + from EGE.SQL.Queries import * # from ..SQL.Utils import Utils # from ..SQL.RandomTable import RandomTable @@ -159,8 +161,33 @@ def test_where_copy_ref(self): with self.subTest(msg='test where ref'): eq('id|1|2|9|4|5', pack_table(tab)) - # def test_update_var(self): + def test_update(self): + eq = self.assertEqual + tab = Table('a b'.split()) + tab.insert_rows([ 1, 2 ], [ 3, 4 ], [ 5, 6 ]) + tab.update(make_block('= a b'.split())) + with self.subTest(msg='test update with var'): + eq('a b|2 2|4 4|6 6', pack_table(tab)) + + tab.update(make_block([ '=', 'a', [ '+', 'a', '1' ] ])) + with self.subTest(msg='test update with expr'): + eq('a b|3 2|5 4|7 6', pack_table(tab)) + + Update(tab, make_block([ '=', 'b', 'a' ])).run() + with self.subTest(msg='test update query'): + eq('a b|3 3|5 5|7 7', pack_table(tab)) + + with self.subTest(msg='test update query none'): + with self.assertRaisesRegex(ValueError, 'none'): + # TODO need to do something with it!!! + # run of block below must be terminated with exception + # because there is no 'none' field/variable + # but in prog module there is an error: + # it isn't fails when variable is undefined + # raise of ValueError was removed in + # https://github.com/klenin/EGEpy/commit/d5d9b0ec9ff4224f62d00708fc865bee5741cf5b + Update(tab, make_block(['=', 'b', 'none'])).run() if __name__ == '__main__': unittest.main(verbosity=1) \ No newline at end of file