Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions architect/orms/django/features.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Defines features for the Django ORM.
"""
from contextlib import closing

from django.conf import settings
from django.db import router, connections, transaction
from django.db.models.fields import FieldDoesNotExist
from django.db.utils import ConnectionDoesNotExist
from django.utils.functional import cached_property

from ..bases import BasePartitionFeature, BaseOperationFeature
from ...exceptions import PartitionColumnError, OptionNotSetError, OptionValueError
Expand All @@ -20,7 +20,7 @@ class ConnectionMixin(object):
def database(self):
return self.options.get('db', router.db_for_write(self.model_cls))

@cached_property
@property
def connection(self):
db = self.database

Expand All @@ -31,9 +31,12 @@ def connection(self):


class OperationFeature(ConnectionMixin, BaseOperationFeature):
def execute(self, sql, autocommit=True):
def execute(self, sql, autocommit=True, connection=None):
if connection is None:
connection = self.connection

if not autocommit:
return self.connection.execute(sql)
return connection.execute(sql)

try:
autocommit = transaction.atomic # Django >= 1.6
Expand All @@ -45,22 +48,23 @@ def execute(self, sql, autocommit=True):
autocommit = transaction.commit_on_success # Django <= 1.5

with autocommit(using=self.database):
return self.connection.execute(sql)
return connection.execute(sql)

def select_one(self, sql):
self.execute(sql)
result = self.connection.fetchone()
return result[0] if result is not None else result
with closing(self.connection) as connection:
self.execute(sql, connection=connection)
result = connection.fetchone()
return result[0] if result is not None else result

def select_all(self, sql, as_dict=False):
self.execute(sql)

if as_dict:
result = [dict(zip([c[0] for c in self.connection.description], row)) for row in self.connection.fetchall()]
else:
result = self.connection.fetchall()
with closing(self.connection) as connection:
self.execute(sql, connection=connection)
if as_dict:
result = [dict(zip([c[0] for c in connection.description], row)) for row in connection.fetchall()]
else:
result = connection.fetchall()

return result
return result


class PartitionFeature(ConnectionMixin, BasePartitionFeature):
Expand Down