diff --git a/.gitignore b/.gitignore index d2d6f36..dd5d326 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ nosetests.xml .mr.developer.cfg .project .pydevproject + +.idea \ No newline at end of file diff --git a/alchmanager.py b/alchmanager.py index 653d1be..31a2ccb 100644 --- a/alchmanager.py +++ b/alchmanager.py @@ -3,7 +3,10 @@ """ __version__ = '0.0.2' __author__ = 'Roman Gladkov' + import types +import inspect +from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.ext.declarative.api import DeclarativeMeta @@ -11,21 +14,58 @@ __all__ = ['ManagedQuery', 'ManagedSession'] not_doubleunder = lambda name: not name.startswith('__') +not_under = lambda name: not name.startswith('_') class ManagedQuery(Query): """Managed Query object""" def __init__(self, entities, *args, **kwargs): - entity = entities[0] - if isinstance(entity, DeclarativeMeta): + self.binds = {} + entity = None + + if isinstance(entities, Mapper): + entity = entities.entity + if isinstance(entity, DeclarativeMeta): + if hasattr(entity, '__manager__'): + manager_cls = entity.__manager__ + for fname in filter(not_doubleunder, dir(manager_cls)): + fn = getattr(manager_cls, fname) + setattr(self, fname, types.MethodType(fn, self)) + + if isinstance(entities, tuple) and len(entities): + entity = entities[0] + + if entity and isinstance(entity, DeclarativeMeta): if hasattr(entity, '__manager__'): manager_cls = entity.__manager__ for fname in filter(not_doubleunder, dir(manager_cls)): fn = getattr(manager_cls, fname) - setattr(self, fname, types.MethodType(fn, self)) + + self.binds.update({fname: fn}) + self.__rebind() + super(ManagedQuery, self).__init__(entities, *args, **kwargs) + def __getattribute__(self, name): + """ Rebind function each function call + + :param name: str + :return: any + """ + returned = object.__getattribute__(self, name) + + if name != '_ManagedQuery__rebind' and \ + (inspect.isfunction(returned) or inspect.ismethod(returned)): + # print('called ', returned.__name__) + self.__rebind() + return returned + + def __rebind(self): + if len(self.binds): + for fname, fn in self.binds.items(): + setattr(self, fname, types.MethodType(fn, self)) + class ManagedSession(Session): diff --git a/tests_flask.py b/tests_flask.py new file mode 100644 index 0000000..c3539b8 --- /dev/null +++ b/tests_flask.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# ~ Author: Pavel Ivanov + +import flask +import unittest +import sqlalchemy as sa +from flask_sqlalchemy import SQLAlchemy + +from alchmanager import ManagedQuery + + +class Config: + DEBUG = False + TESTING = True + + SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' + SQLALCHEMY_TRACK_MODIFICATIONS = True + + +app = flask.Flask(__name__) +db = SQLAlchemy(query_class=ManagedQuery) + + +class MainManager: + + @staticmethod + def is_index(query): + return query.filter_by(is_index=True) + + @staticmethod + def is_public(query): + return query.filter_by(is_public=True) + + +class Main(db.Model): + __tablename__ = 'main' + id = sa.Column(sa.Integer, primary_key=True) + child = sa.Column(sa.Integer, index=True) + preview = sa.Column(sa.String(50)) + typeMedia = sa.Column(sa.Integer) + is_index = sa.Column(sa.Boolean, default=False) + is_public = sa.Column(sa.Boolean, default=False) + + __manager__ = MainManager + + __mapper_args__ = {'polymorphic_on': typeMedia} + + +class Video(Main): + __tablename__ = 'video' + videoid = sa.Column(sa.Integer, sa.ForeignKey(Main.child), primary_key=True) + movie = sa.Column(sa.String(50)) + __mapper_args__ = {'polymorphic_identity': 1, + 'inherit_condition': (Main.typeMedia == 1) & + (Main.child == videoid)} + + +@app.route('/testing-queries-v1', methods=['POST']) +def run_testing_queires_v1(): + standart_query = db.session.query(Video).filter_by(is_index=True).all() + managed_query = db.session.query(Video).is_index().all() + + try: + assert standart_query == managed_query + except AssertionError: + flask.abort(500) + + return '' + + +@app.route('/testing-subclass-query-v1', methods=['POST']) +def run_testing_subclass_query_v1(): + try: + assert hasattr(db.session.query(Video), 'is_index') + assert callable(db.session.query(Video).is_index) + assert callable( + db.session.query(Video).is_index().filter_by(child=1).is_public + ) + except AssertionError: + flask.abort(500) + + return '' + + +@app.route('/testing-queries-v2', methods=['POST']) +def run_testing_queires_v2(): + standart_query = Video.query.filter_by(is_index=True).all() + managed_query = Video.query.is_index().all() + + try: + assert standart_query == managed_query + except AssertionError: + flask.abort(500) + + return '' + + +@app.route('/testing-subclass-query-v2', methods=['POST']) +def run_testing_subclass_query_v2(): + try: + assert hasattr(Video.query, 'is_index') + assert callable(Video.query.is_index) + assert callable(Video.query.is_index().filter_by(child=1).is_public) + except AssertionError: + flask.abort(500) + + return '' + + +class TestsQueryManager(unittest.TestCase): + + def setUp(self): + self.app = None + + app.config.from_object(Config()) + + with app.app_context(): + db.init_app(app) + db.create_all() + + self.app = app.test_client() + + def test_post_v1(self): + response = self.app.post('/testing-queries-v1') + self.assertEqual(response.status_code, 200) + + def test_post_v2(self): + response = self.app.post('/testing-queries-v2') + self.assertEqual(response.status_code, 200) + + def test_post_v3(self): + response = self.app.post('/testing-subclass-query-v1') + self.assertEqual(response.status_code, 200) + + def test_post_v4(self): + response = self.app.post('/testing-subclass-query-v2') + self.assertEqual(response.status_code, 200) + + +if __name__ == "__main__": + unittest.main()