From 065b916e6a93d4ac6855453110132886a79a4de2 Mon Sep 17 00:00:00 2001 From: Pavel Ivanov Date: Sat, 13 Oct 2018 10:05:10 +0300 Subject: [PATCH 1/4] Add Idea files to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index d2d6f36..dd5d326 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ nosetests.xml .mr.developer.cfg .project .pydevproject + +.idea \ No newline at end of file From 1126cb007aa840bb39b7cba3440b202932e60974 Mon Sep 17 00:00:00 2001 From: Pavel Ivanov Date: Sat, 13 Oct 2018 10:09:48 +0300 Subject: [PATCH 2/4] Fix for sqlachemy 1.2.12 --- alchmanager.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/alchmanager.py b/alchmanager.py index 653d1be..e3e9e79 100644 --- a/alchmanager.py +++ b/alchmanager.py @@ -4,6 +4,7 @@ __version__ = '0.0.2' __author__ = 'Roman Gladkov' import types +from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.ext.declarative.api import DeclarativeMeta @@ -17,13 +18,14 @@ class ManagedQuery(Query): """Managed Query object""" def __init__(self, entities, *args, **kwargs): - entity = entities[0] - if isinstance(entity, DeclarativeMeta): - if hasattr(entity, '__manager__'): - manager_cls = entity.__manager__ - for fname in filter(not_doubleunder, dir(manager_cls)): - fn = getattr(manager_cls, fname) - setattr(self, fname, types.MethodType(fn, self)) + if isinstance(entities, Mapper): + entity = entities.entity + if isinstance(entity, DeclarativeMeta): + if hasattr(entity, '__manager__'): + manager_cls = entity.__manager__ + for fname in filter(not_doubleunder, dir(manager_cls)): + fn = getattr(manager_cls, fname) + setattr(self, fname, types.MethodType(fn, self)) super(ManagedQuery, self).__init__(entities, *args, **kwargs) From 55b317c91c3194e69b1e8a0f1699a02fa3a27afc Mon Sep 17 00:00:00 2001 From: Pavel Ivanov Date: Sat, 13 Oct 2018 16:24:15 +0300 Subject: [PATCH 3/4] =?UTF-8?q?=D0=94=D0=BE=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=BE=20=D0=B8=D1=81=D0=BF=D1=80=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=B8=D0=B5=20=D0=B4=D0=BB=D1=8F=20=D0=B4=D0=B2=D1=83?= =?UTF-8?q?=D1=85=20=D0=B2=D0=B0=D1=80=D0=B8=D0=BD=D1=82=D0=BE=D0=B2=20?= =?UTF-8?q?=D0=B8=D1=81=D0=BF=D0=BE=D0=B7=D0=BE=D0=B2=D0=B0=D0=BD=D0=B8?= =?UTF-8?q?=D1=8F:=20-=20Foo.query.baz()=20-=20db.session.query(Foo).baz()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- alchmanager.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/alchmanager.py b/alchmanager.py index e3e9e79..a0f2f53 100644 --- a/alchmanager.py +++ b/alchmanager.py @@ -18,14 +18,18 @@ class ManagedQuery(Query): """Managed Query object""" def __init__(self, entities, *args, **kwargs): + entity = None if isinstance(entities, Mapper): entity = entities.entity - if isinstance(entity, DeclarativeMeta): - if hasattr(entity, '__manager__'): - manager_cls = entity.__manager__ - for fname in filter(not_doubleunder, dir(manager_cls)): - fn = getattr(manager_cls, fname) - setattr(self, fname, types.MethodType(fn, self)) + if isinstance(entities, tuple) and len(entities): + entity = entities[0] + + if entity and isinstance(entity, DeclarativeMeta): + if hasattr(entity, '__manager__'): + manager_cls = entity.__manager__ + for fname in filter(not_doubleunder, dir(manager_cls)): + fn = getattr(manager_cls, fname) + setattr(self, fname, types.MethodType(fn, self)) super(ManagedQuery, self).__init__(entities, *args, **kwargs) From 17fa82a4385a2c3d52376d4ce5878eb6e6435913 Mon Sep 17 00:00:00 2001 From: Pavel Ivanov Date: Sat, 13 Oct 2018 23:39:19 +0300 Subject: [PATCH 4/4] =?UTF-8?q?=D0=94=D0=BE=D0=BF=D0=B8=D1=81=D0=B0=D0=BB?= =?UTF-8?q?=20=D0=BA=D0=BB=D0=B0=D1=81=D1=81,=20=D0=B4=D0=BE=D0=B1=D0=B0?= =?UTF-8?q?=D0=B2=D0=B8=D0=BB=20=D1=82=D0=B5=D1=81=D1=82=D1=8B=20=D0=BD?= =?UTF-8?q?=D0=B0=20=D1=84=D0=BB=D0=B0=D1=81=D0=BA=D0=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- alchmanager.py | 36 ++++++++++++- tests_flask.py | 142 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 tests_flask.py diff --git a/alchmanager.py b/alchmanager.py index a0f2f53..31a2ccb 100644 --- a/alchmanager.py +++ b/alchmanager.py @@ -3,7 +3,9 @@ """ __version__ = '0.0.2' __author__ = 'Roman Gladkov' + import types +import inspect from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session @@ -12,15 +14,25 @@ __all__ = ['ManagedQuery', 'ManagedSession'] not_doubleunder = lambda name: not name.startswith('__') +not_under = lambda name: not name.startswith('_') class ManagedQuery(Query): """Managed Query object""" def __init__(self, entities, *args, **kwargs): + self.binds = {} entity = None + if isinstance(entities, Mapper): entity = entities.entity + if isinstance(entity, DeclarativeMeta): + if hasattr(entity, '__manager__'): + manager_cls = entity.__manager__ + for fname in filter(not_doubleunder, dir(manager_cls)): + fn = getattr(manager_cls, fname) + setattr(self, fname, types.MethodType(fn, self)) + if isinstance(entities, tuple) and len(entities): entity = entities[0] @@ -29,9 +41,31 @@ def __init__(self, entities, *args, **kwargs): manager_cls = entity.__manager__ for fname in filter(not_doubleunder, dir(manager_cls)): fn = getattr(manager_cls, fname) - setattr(self, fname, types.MethodType(fn, self)) + + self.binds.update({fname: fn}) + self.__rebind() + super(ManagedQuery, self).__init__(entities, *args, **kwargs) + def __getattribute__(self, name): + """ Rebind function each function call + + :param name: str + :return: any + """ + returned = object.__getattribute__(self, name) + + if name != '_ManagedQuery__rebind' and \ + (inspect.isfunction(returned) or inspect.ismethod(returned)): + # print('called ', returned.__name__) + self.__rebind() + return returned + + def __rebind(self): + if len(self.binds): + for fname, fn in self.binds.items(): + setattr(self, fname, types.MethodType(fn, self)) + class ManagedSession(Session): diff --git a/tests_flask.py b/tests_flask.py new file mode 100644 index 0000000..c3539b8 --- /dev/null +++ b/tests_flask.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# ~ Author: Pavel Ivanov + +import flask +import unittest +import sqlalchemy as sa +from flask_sqlalchemy import SQLAlchemy + +from alchmanager import ManagedQuery + + +class Config: + DEBUG = False + TESTING = True + + SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' + SQLALCHEMY_TRACK_MODIFICATIONS = True + + +app = flask.Flask(__name__) +db = SQLAlchemy(query_class=ManagedQuery) + + +class MainManager: + + @staticmethod + def is_index(query): + return query.filter_by(is_index=True) + + @staticmethod + def is_public(query): + return query.filter_by(is_public=True) + + +class Main(db.Model): + __tablename__ = 'main' + id = sa.Column(sa.Integer, primary_key=True) + child = sa.Column(sa.Integer, index=True) + preview = sa.Column(sa.String(50)) + typeMedia = sa.Column(sa.Integer) + is_index = sa.Column(sa.Boolean, default=False) + is_public = sa.Column(sa.Boolean, default=False) + + __manager__ = MainManager + + __mapper_args__ = {'polymorphic_on': typeMedia} + + +class Video(Main): + __tablename__ = 'video' + videoid = sa.Column(sa.Integer, sa.ForeignKey(Main.child), primary_key=True) + movie = sa.Column(sa.String(50)) + __mapper_args__ = {'polymorphic_identity': 1, + 'inherit_condition': (Main.typeMedia == 1) & + (Main.child == videoid)} + + +@app.route('/testing-queries-v1', methods=['POST']) +def run_testing_queires_v1(): + standart_query = db.session.query(Video).filter_by(is_index=True).all() + managed_query = db.session.query(Video).is_index().all() + + try: + assert standart_query == managed_query + except AssertionError: + flask.abort(500) + + return '' + + +@app.route('/testing-subclass-query-v1', methods=['POST']) +def run_testing_subclass_query_v1(): + try: + assert hasattr(db.session.query(Video), 'is_index') + assert callable(db.session.query(Video).is_index) + assert callable( + db.session.query(Video).is_index().filter_by(child=1).is_public + ) + except AssertionError: + flask.abort(500) + + return '' + + +@app.route('/testing-queries-v2', methods=['POST']) +def run_testing_queires_v2(): + standart_query = Video.query.filter_by(is_index=True).all() + managed_query = Video.query.is_index().all() + + try: + assert standart_query == managed_query + except AssertionError: + flask.abort(500) + + return '' + + +@app.route('/testing-subclass-query-v2', methods=['POST']) +def run_testing_subclass_query_v2(): + try: + assert hasattr(Video.query, 'is_index') + assert callable(Video.query.is_index) + assert callable(Video.query.is_index().filter_by(child=1).is_public) + except AssertionError: + flask.abort(500) + + return '' + + +class TestsQueryManager(unittest.TestCase): + + def setUp(self): + self.app = None + + app.config.from_object(Config()) + + with app.app_context(): + db.init_app(app) + db.create_all() + + self.app = app.test_client() + + def test_post_v1(self): + response = self.app.post('/testing-queries-v1') + self.assertEqual(response.status_code, 200) + + def test_post_v2(self): + response = self.app.post('/testing-queries-v2') + self.assertEqual(response.status_code, 200) + + def test_post_v3(self): + response = self.app.post('/testing-subclass-query-v1') + self.assertEqual(response.status_code, 200) + + def test_post_v4(self): + response = self.app.post('/testing-subclass-query-v2') + self.assertEqual(response.status_code, 200) + + +if __name__ == "__main__": + unittest.main()