Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ nosetests.xml
.mr.developer.cfg
.project
.pydevproject

.idea
46 changes: 43 additions & 3 deletions alchmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,69 @@
"""
__version__ = '0.0.2'
__author__ = 'Roman Gladkov'

import types
import inspect
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.ext.declarative.api import DeclarativeMeta

__all__ = ['ManagedQuery', 'ManagedSession']

not_doubleunder = lambda name: not name.startswith('__')
not_under = lambda name: not name.startswith('_')


class ManagedQuery(Query):
"""Managed Query object"""

def __init__(self, entities, *args, **kwargs):
entity = entities[0]
if isinstance(entity, DeclarativeMeta):
self.binds = {}
entity = None

if isinstance(entities, Mapper):
entity = entities.entity
if isinstance(entity, DeclarativeMeta):
if hasattr(entity, '__manager__'):
manager_cls = entity.__manager__
for fname in filter(not_doubleunder, dir(manager_cls)):
fn = getattr(manager_cls, fname)
setattr(self, fname, types.MethodType(fn, self))

if isinstance(entities, tuple) and len(entities):
entity = entities[0]

if entity and isinstance(entity, DeclarativeMeta):
if hasattr(entity, '__manager__'):
manager_cls = entity.__manager__
for fname in filter(not_doubleunder, dir(manager_cls)):
fn = getattr(manager_cls, fname)
setattr(self, fname, types.MethodType(fn, self))

self.binds.update({fname: fn})
self.__rebind()

super(ManagedQuery, self).__init__(entities, *args, **kwargs)

def __getattribute__(self, name):
""" Rebind function each function call

:param name: str
:return: any
"""
returned = object.__getattribute__(self, name)

if name != '_ManagedQuery__rebind' and \
(inspect.isfunction(returned) or inspect.ismethod(returned)):
# print('called ', returned.__name__)
self.__rebind()
return returned

def __rebind(self):
if len(self.binds):
for fname, fn in self.binds.items():
setattr(self, fname, types.MethodType(fn, self))


class ManagedSession(Session):

Expand Down
142 changes: 142 additions & 0 deletions tests_flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ~ Author: Pavel Ivanov

import flask
import unittest
import sqlalchemy as sa
from flask_sqlalchemy import SQLAlchemy

from alchmanager import ManagedQuery


class Config:
DEBUG = False
TESTING = True

SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
SQLALCHEMY_TRACK_MODIFICATIONS = True


app = flask.Flask(__name__)
db = SQLAlchemy(query_class=ManagedQuery)


class MainManager:

@staticmethod
def is_index(query):
return query.filter_by(is_index=True)

@staticmethod
def is_public(query):
return query.filter_by(is_public=True)


class Main(db.Model):
__tablename__ = 'main'
id = sa.Column(sa.Integer, primary_key=True)
child = sa.Column(sa.Integer, index=True)
preview = sa.Column(sa.String(50))
typeMedia = sa.Column(sa.Integer)
is_index = sa.Column(sa.Boolean, default=False)
is_public = sa.Column(sa.Boolean, default=False)

__manager__ = MainManager

__mapper_args__ = {'polymorphic_on': typeMedia}


class Video(Main):
__tablename__ = 'video'
videoid = sa.Column(sa.Integer, sa.ForeignKey(Main.child), primary_key=True)
movie = sa.Column(sa.String(50))
__mapper_args__ = {'polymorphic_identity': 1,
'inherit_condition': (Main.typeMedia == 1) &
(Main.child == videoid)}


@app.route('/testing-queries-v1', methods=['POST'])
def run_testing_queires_v1():
standart_query = db.session.query(Video).filter_by(is_index=True).all()
managed_query = db.session.query(Video).is_index().all()

try:
assert standart_query == managed_query
except AssertionError:
flask.abort(500)

return ''


@app.route('/testing-subclass-query-v1', methods=['POST'])
def run_testing_subclass_query_v1():
try:
assert hasattr(db.session.query(Video), 'is_index')
assert callable(db.session.query(Video).is_index)
assert callable(
db.session.query(Video).is_index().filter_by(child=1).is_public
)
except AssertionError:
flask.abort(500)

return ''


@app.route('/testing-queries-v2', methods=['POST'])
def run_testing_queires_v2():
standart_query = Video.query.filter_by(is_index=True).all()
managed_query = Video.query.is_index().all()

try:
assert standart_query == managed_query
except AssertionError:
flask.abort(500)

return ''


@app.route('/testing-subclass-query-v2', methods=['POST'])
def run_testing_subclass_query_v2():
try:
assert hasattr(Video.query, 'is_index')
assert callable(Video.query.is_index)
assert callable(Video.query.is_index().filter_by(child=1).is_public)
except AssertionError:
flask.abort(500)

return ''


class TestsQueryManager(unittest.TestCase):

def setUp(self):
self.app = None

app.config.from_object(Config())

with app.app_context():
db.init_app(app)
db.create_all()

self.app = app.test_client()

def test_post_v1(self):
response = self.app.post('/testing-queries-v1')
self.assertEqual(response.status_code, 200)

def test_post_v2(self):
response = self.app.post('/testing-queries-v2')
self.assertEqual(response.status_code, 200)

def test_post_v3(self):
response = self.app.post('/testing-subclass-query-v1')
self.assertEqual(response.status_code, 200)

def test_post_v4(self):
response = self.app.post('/testing-subclass-query-v2')
self.assertEqual(response.status_code, 200)


if __name__ == "__main__":
unittest.main()