diff --git a/.gitignore b/.gitignore index 55e63fc..6f10ed6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.pyc t.xhtml t.py +.idea +.venv diff --git a/EGE/SQL/Aggregate.py b/EGE/SQL/Aggregate.py new file mode 100644 index 0000000..5a61dab --- /dev/null +++ b/EGE/SQL/Aggregate.py @@ -0,0 +1,29 @@ +class Aggregate: + @staticmethod + def call(self): # abstract + pass + +class Count(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Sum(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Avg(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Min(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() + +class Max(Aggregate): + @staticmethod + def call(self): + raise NotImplemented() \ No newline at end of file diff --git a/EGE/SQL/Queries.py b/EGE/SQL/Queries.py new file mode 100644 index 0000000..16dffb5 --- /dev/null +++ b/EGE/SQL/Queries.py @@ -0,0 +1,81 @@ +from abc import ABC, abstractmethod + +import EGE.Html as html +from EGE.GenBase import EGEError +from EGE.SQL.Table import Table, Field +from EGE.Prog import SynElement, Block + +class Query(ABC): + def __init__(self, table, *, + where: SynElement = None, + having: SynElement = None, + group_by: list[SynElement] = None): + """table: Union[Table, str]""" + if table is None: + raise EGEError('Table is none!') + + self.table = table if isinstance(table, Table) else None + self.table_name = table.name if isinstance(table, Table) else table + + self.where: SynElement = where + self.having: SynElement = having + self.group_by: list[SynElement] = group_by + + @abstractmethod + def text(self, opts: dict): + """opts: dict[str, Any]""" + pass + + @abstractmethod + def run(self): + pass + + def text_html(self): + return self.text({ 'html': 1 }) + + def text_html_tt(self): + return html.tag('tt', self.text({ 'html': 1 })) + + def _field_list_sql(self, fields: list, opts: dict): + """ + fields: list[SynElement | str] + opts: dict[str, Any] + """ + return ', '.join([ f.to_lang_named('SQL', opts) + if issubclass(type(f), SynElement) else f + for f in fields ]) + + def where_sql(self, opts: dict): + return f"WHERE {self.where.to_lang_named('SQL', opts)}" if self.where else '' + + def having_sql(self, opts: dict): + return f"HAVING {self.having.to_lang_named('SQL', opts)}" if self.having else '' + + def group_by_sql(self, opts: dict): + return f"GROUP BY {self._field_list_sql(self.group_by, opts)}" if self.group_by else '' + + def _maybe_run(self): + with getattr(self, 'run', None) as run: + if callable(run): + return run() + return self + +class Update(Query): + def __init__(self, table, assigns: Block, *, where: SynElement = None): + if assigns is None: + raise EGEError('Assigns is none!') + super(Update, self).__init__(table, where=where) + self.assigns: Block = assigns + + def run(self): + self.table.update(self.assigns, self.where) + # TODO check if need return table + + def text(self, opts: dict): + """opts: dict[str, Any]""" + assigns = self.assigns.to_lang_named('SQL') + where_sql = self.where_sql(opts) + if where_sql: + return f"UPDATE {self.table_name} SET {assigns} {where_sql}" + else: + return f"UPDATE {self.table_name} SET {assigns}" diff --git a/EGE/SQL/RandomTable.py b/EGE/SQL/RandomTable.py new file mode 100644 index 0000000..5656106 --- /dev/null +++ b/EGE/SQL/RandomTable.py @@ -0,0 +1,69 @@ +from EGE.Random import Random + +class BaseTable: + columns: list + def __init__(self, rnd: Random, col_count, row_count, name): + self.rnd = rnd + self.columns = self.get_columns() + self.fields = [ + self.columns[0], + rnd.pick_n( + col_count - 1, + self.columns[1:len(self.columns)] + ) + ] + row_sources = [ row for row in self.get_rows_array() + if len(row) >= row_count ] + self.rows = rnd.pick_n(row_count, row_sources) + + + def get_columns(self) -> list: + return [] + + def get_rows_array(self): + pass + +class Products(BaseTable): + pass + +class Jobs(BaseTable): + pass + +class SalesMonth(BaseTable): + pass + +class Cities(BaseTable): + pass + +class People(BaseTable): + pass + +class Subjects(BaseTable): + pass + +class Marks(BaseTable): + pass + +class ParticipantsMonth(BaseTable): + pass + +def ok_table(table: BaseTable, rows, cols): + t_cols = table.get_columns() + t_rows = table.get_rows_array() + return t_cols >= cols and any([len(t_rows) > rows]) + +def pick(rnd: Random, *args): + return rnd.pick([ + lambda *args2: Products(args2), + lambda *args2: Jobs(args2), + lambda *args2: SalesMonth(args2), + lambda *args2: Cities(args2), + lambda *args2: People(args2), + lambda *args2: Subjects(args2), + lambda *args2: Marks(args2), + lambda *args2: ParticipantsMonth(args2), + ])(args) + +def create_table(rnd: Random, rows: int, columns: int): + table = pick(rnd, rows, columns) + return table diff --git a/EGE/SQL/Table.py b/EGE/SQL/Table.py new file mode 100644 index 0000000..e67ce52 --- /dev/null +++ b/EGE/SQL/Table.py @@ -0,0 +1,206 @@ +from EGE.GenBase import EGEError +from EGE.Prog import CallFuncAggregate, make_expr, Op, SynElement, Block +from EGE.Random import Random +from EGE.Utils import aggregate_function + +class Field: + def __init__(self, attr): + self.name: str = '' + self.name_alias: str = '' + if isinstance(attr, dict): + self.name = attr['name'] + self.name_alias = attr['name_alias'] + else: + self.name = attr + + def to_lang(self): + return self.name_alias + '.' + self.name if self.name_alias else self.name + + def __str__(self): + return self.name + + def __repr__(self): + if self.name_alias: + return f"Field({{'name': '{self.name}', 'name_alias': '{self.name_alias}'}})" + return f"Field('{self.name}')" + + def __eq__(self, other): + if isinstance(other, str): + return other == self.name or other == self.name_alias + elif isinstance(other, Field): + return other.name == self.name and other.name_alias == self.name_alias + raise EGEError("Imvalid type to compare with Field!") + + + +class Table: + def __init__(self, fields: list, **kwargs): + if fields is None: + raise EGEError('No fields') + + self.fields: list[Field] = [ self._make_field(i) for i in fields ] + self.name: str = kwargs['name'] if 'name' in kwargs.keys() else '' + self.field_index: dict = {} + self._update_field_index() + for i in self.fields: + i.table = self + self.data: list = [] + + def _make_field(self, field): + return field if isinstance(field, Field) else Field(field) + + def _update_field_index(self): + for i, key in enumerate(self.fields): + self.field_index[key.name] = i + + def assign_field_alias(self, alias: str): + raise NotImplemented() + + def insert_row(self, *fields): + if len(fields) != len(self.fields): + EGEError(f'Wrong column count {len(fields)} != {len(self.fields)}') + self.data.append(list(fields)) + return self + + def insert_rows(self, *rows): + for fields in rows: + self.insert_row(*fields) + return self + + def insert_column(self): + raise NotImplemented() + + def print_row(self, row): + raise NotImplemented() + + def print(self): + raise NotImplemented() + + def count(self): + return len(self.data) + + def _row_hash(self, row, env=None): + if env is None: + env = {} + for f in self.fields: + env[f.name] = row[self.field_index[f.name]] + return env + + def _hash(self): + env = { '&columns': { str(f): self.column_array(str(f)) for f in self.fields }, + '&': aggregate_function(), + '&count': self.count() } + return env + + def select(self, fields, where=None, p=None): + if not isinstance(fields, list): + fields = [ fields ] + for f in fields: + if not isinstance(f, CallFuncAggregate) and f not in self.fields: + raise EGEError(f"Table '{self.name}' does not contain field '{f}'!") + ref, aggr, group, having = 0, 0, 0, 0 + if isinstance(ref, dict): + ref = p[ref] + group = p[group] + having = p[having] + + aggr = list(filter(lambda x: isinstance(x, CallFuncAggregate), fields)) + k = 0 + args = [] + exprs = [ f'expr_{i}' for i in range(1, 4) ] + for f in fields: + # TODO this seems to be wrong. Check later + if not isinstance(f, Field) and hasattr(f, '__call__') and f.__name__ in exprs: + k += 1 + args.append(f'expr_{k}') + else: + args.append(f) + result = Table(args) + values = [ f if hasattr(f, '__call__') else make_expr(f) for f in fields ] + calc_row = lambda x: [ i.run(x) for i in values ] + + tab_where = self.where(where, ref) + if group: + raise NotImplemented() + else: + ans = [] + env = tab_where._hash() + for data in tab_where.data: + ans.append(calc_row(tab_where._row_hash(data, env))) + result.data = [ ans[0] ] if aggr else ans #TODO check correctness in case if aggr == True + + return result + + def group_by(self): + raise NotImplemented() + + def where(self, where: Op = None, ref: bool = None): + """where: Union[callable, None], ref: Union[bool, None]""" + if where is None: + return self + table = Table(self.fields) + table.data = [ data if ref else data[:] for data in self.data if where.run(self._row_hash(data)) ] + return table + + def count_where(self, where: Op = None): + if where is None: + return self.count() + res = list(filter(lambda x: where.run(self._row_hash(x)), self.data)) + return len(res) + + def update(self, assigns: Block, where: Op = None): + data = self.data if where is None else self.where(where, True) + for row in data: + row_hash = self._row_hash(row) + assigns.run(row_hash) + for f in self.fields: + row[self._field_index(f.name)] = row_hash[f.name] + return self + + def delete(self): + raise NotImplemented() + + def natural_join(self): + raise NotImplemented() + + def inner_join(self): + raise NotImplemented() + + def inner_join_expr(self): + raise NotImplemented() + + def table_html(self): + raise NotImplemented() + + def fetch_val(self): + raise NotImplemented() + + #TODO how best to implement getting the random number generator in table methods + def random_row(self, rnd: Random): + return rnd.pick(self.data) + + def _field_index(self, field: str): + """field: Union[int, str]""" + if isinstance(field, int): + if 1 <= field <= len(self.fields): + return field - 1 + else: + raise EGEError(f"Unknown field {field}") + if field not in self.field_index.keys(): + raise EGEError(f"Unknown field {field}") + return self.field_index[field] + + def column_array(self, field): + """field: Union[int, str]""" + column_idx = self._field_index(field) + return [ data[column_idx] for data in self.data ] + + def column_hash(self, field): + """field: Union[int, str]""" + column_idx = self._field_index(field) + r = {} + for row in self.data: + if row[column_idx] not in r.keys(): + r[row[column_idx]] = 0 + r[row[column_idx]] += 1 + return r diff --git a/EGE/SQL/Utils.py b/EGE/SQL/Utils.py new file mode 100644 index 0000000..a7a8b0f --- /dev/null +++ b/EGE/SQL/Utils.py @@ -0,0 +1,16 @@ +from EGE.Random import Random +# from EGE.SQL.Table import Aggregate + +def create_table(rnd: Random, row: int, col: int, name: str): + product = None + return product + +# def aggregate_function(name=None): +# """name: Union[str, None]""" +# aggr = Aggregate.__subclasses__() +# if name is not None: +# for sub in aggr: +# if name == sub.__name__: +# return sub +# return [sub for sub in aggr] + diff --git a/EGE/SQL/__init__.py b/EGE/SQL/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/EGE/Utils.py b/EGE/Utils.py index 5eb3cf8..bc2d71b 100644 --- a/EGE/Utils.py +++ b/EGE/Utils.py @@ -1,8 +1,14 @@ import math - -def aggregate_function(name): - # TODO - return None +from EGE.SQL.Aggregate import Aggregate + +def aggregate_function(name:str = None): + '''name: Union[str, None]''' + aggr = Aggregate.__subclasses__() + if name is not None: + for sub in aggr: + if name == sub.__name__: + return sub + return [ sub for sub in aggr ] def last_key(d, key): while key in d[key]: diff --git a/EGE/test/test_SQL.py b/EGE/test/test_SQL.py new file mode 100644 index 0000000..689b8f3 --- /dev/null +++ b/EGE/test/test_SQL.py @@ -0,0 +1,193 @@ +import unittest + +from EGE.GenBase import EGEError +from EGE.Random import Random +from EGE.Prog import make_expr, make_block +from EGE.Utils import nrange + +if __name__ == '__main__': + import sys + sys.path.append('..') + from SQL.Table import Table + from SQL.Queries import * + # from SQL.Utils import Utils + # from SQL.RandomTable import RandomTable +else: + from EGE.SQL.Table import Table + from EGE.SQL.Queries import * + # from ..SQL.Utils import Utils + # from ..SQL.RandomTable import RandomTable + +rnd = Random(2342134) + +def join_sp(*elements): + res = ' '.join([str(el) for el in elements]) + return res + +def pack_table(table: Table): + res = '|'.join([ join_sp(*[ str(f) for f in table.fields ]), *[ join_sp(*d) for d in table.data ] ]) + return res + +def pack_table_sorted(table: Table): + return '|'.join([ join_sp(table.fields) ] + [ join_sp(d) for d in sorted(table.data) ]) + +#TODO reorganize tests to multiple test cases to simplify structure of complex test_methods +class test_SQL(unittest.TestCase): + def test_create_table(self): + eq = self.assertEqual + + with self.assertRaisesRegex(EGEError, 'fields', msg='no fields'): + t = Table(None) #test seems to be useless because python throw exception if positional argument is missed + t = Table(['a', 'b', 'c'], name='table') + eq(t.name, 'table', 'table name') + + def test_insert_row_copies(self): + eq = self.assertEqual + + tab = Table(['f'], name='table') + r = 1 + tab.insert_row(r) + r = 2 + eq('f|1', pack_table(tab), 'insert_row copies') + + def test_insert_rows_copies(self): + eq = self.assertEqual + + tab = Table(['f'], name='table') + r = [ 1 ] + tab.insert_rows(r, r) + r[0] = 2 + eq('f|1|1', pack_table(tab), 'insert_row copies') + + def test_select_insert_hash(self): + eq = self.assertEqual + tab = Table('id name'.split()) + + with self.subTest(msg='test _row_hash'): + self.assertDictEqual(tab._row_hash([ 2, 3 ]), { 'id': 2, 'name': 3 }) + + tab.insert_rows([ 1, 'aaa' ], [ 2, 'bbb' ]) + with self.subTest(msg='select all fields'): + eq('id name|1 aaa|2 bbb', pack_table(tab.select([ 'id', 'name' ]))) + + tab.insert_row(3, 'ccc') + with self.subTest(msg='test select field id'): + eq('id|1|2|3', pack_table(tab.select('id'))) + + with self.subTest(msg='test column_array (str arg)'): + self.assertListEqual([ 1, 2, 3 ], tab.column_array('id')) + + with self.subTest(msg='test column_array (number arg)'): + self.assertListEqual([ 1, 2, 3 ], tab.column_array(1)) + + with self.subTest(msg='test column_array none (str arg)'): + with self.assertRaisesRegex(EGEError, 'zzz', msg='column_array none'): + tab.column_array('zzz') + + with self.subTest(msg='test column_array none (number arg)'): + with self.assertRaisesRegex(EGEError, '77', msg='column_array by number none'): + tab.column_array(77) + + with self.subTest(msg='test column_hash (str arg)'): + self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash('name')) + + with self.subTest(msg='test column_hash (number arg)'): + self.assertDictEqual({ 'aaa': 1, 'bbb': 1, 'ccc': 1 }, tab.column_hash(2)) + + with self.subTest(msg='test column_hash none (str argument)'): + with self.assertRaisesRegex(EGEError, 'xxx'): + tab.column_hash('xxx') + + with self.subTest(msg='test column_hash none (number argument)'): + with self.assertRaisesRegex(EGEError, '42'): + tab.column_hash(42) + + with self.subTest(msg='test select field name'): + eq('name|aaa|bbb|ccc', pack_table(tab.select('name'))) + + with self.subTest(msg='test select two same fields'): + eq('id id|1 1|2 2|3 3', pack_table(tab.select([ 'id', 'id' ]))) + + with self.subTest(msg='test select non existing field'): + with self.assertRaisesRegex(EGEError, 'zzz'): + tab.select('zzz') + + r = tab.random_row(rnd)[0] + with self.subTest(msg='test random row'): + self.assertTrue(r == 1 or r == 2 or r == 3) + + def test_select_where(self): + eq = self.assertEqual + + tab = Table('id name city'.split()) + tab.insert_rows([ 1, 'aaa', 3 ], [ 2, 'bbb', 2 ], [ 3, 'ccc', 1 ], [ 4, 'bbn', 2 ]) + e = make_expr([ '==', 'city', 2 ]) + + with self.subTest(msg='test where city == 2'): + eq('id name city|2 bbb 2|4 bbn 2', pack_table(tab.where(e))) + + with self.subTest(msg='test count_where city == 2'): + eq(2, tab.count_where(e)) + + with self.subTest(msg='select id, name where city == 2'): + eq('id name|2 bbb|4 bbn', pack_table(tab.select(['id', 'name'], e))) + + with self.subTest(msg='test select where city == 2'): + eq('||', pack_table(tab.select([], e))) + + with self.subTest(msg='test count rows'): + eq(4, tab.count()) + + with self.subTest(msg='test where false'): + eq(0, tab.where(make_expr(0)).count()) + + with self.subTest(msg='test count_where false'): + eq(0, tab.count_where(make_expr(0))) + + def test_where_copy_ref(self): + eq = self.assertEqual + + tab = Table([ 'id' ]) + tab.insert_rows(*[ [i] for i in range(1, 6) ]) + e = make_expr([ '==', 'id', '3' ]) + + w2 = tab.where(e) + w2.data[0][0] = 9 + with self.subTest(msg='test where copy'): + eq('id|1|2|3|4|5', pack_table(tab)) + + w1 = tab.where(e, True) + w1.data[0][0] = 9 + with self.subTest(msg='test where ref'): + eq('id|1|2|9|4|5', pack_table(tab)) + + def test_update(self): + eq = self.assertEqual + + tab = Table('a b'.split()) + tab.insert_rows([ 1, 2 ], [ 3, 4 ], [ 5, 6 ]) + tab.update(make_block('= a b'.split())) + with self.subTest(msg='test update with var'): + eq('a b|2 2|4 4|6 6', pack_table(tab)) + + tab.update(make_block([ '=', 'a', [ '+', 'a', '1' ] ])) + with self.subTest(msg='test update with expr'): + eq('a b|3 2|5 4|7 6', pack_table(tab)) + + Update(tab, make_block([ '=', 'b', 'a' ])).run() + with self.subTest(msg='test update query'): + eq('a b|3 3|5 5|7 7', pack_table(tab)) + + with self.subTest(msg='test update query none'): + with self.assertRaisesRegex(ValueError, 'none'): + # TODO need to do something with it!!! + # run of block below must be terminated with exception + # because there is no 'none' field/variable + # but in prog module there is an error: + # it isn't fails when variable is undefined + # raise of ValueError was removed in + # https://github.com/klenin/EGEpy/commit/d5d9b0ec9ff4224f62d00708fc865bee5741cf5b + Update(tab, make_block(['=', 'b', 'none'])).run() + +if __name__ == '__main__': + unittest.main(verbosity=1) \ No newline at end of file