diff --git a/architect/orms/django/features.py b/architect/orms/django/features.py index 75ee77b..e5c06d1 100644 --- a/architect/orms/django/features.py +++ b/architect/orms/django/features.py @@ -1,12 +1,12 @@ """ Defines features for the Django ORM. """ +from contextlib import closing from django.conf import settings from django.db import router, connections, transaction from django.db.models.fields import FieldDoesNotExist from django.db.utils import ConnectionDoesNotExist -from django.utils.functional import cached_property from ..bases import BasePartitionFeature, BaseOperationFeature from ...exceptions import PartitionColumnError, OptionNotSetError, OptionValueError @@ -20,7 +20,7 @@ class ConnectionMixin(object): def database(self): return self.options.get('db', router.db_for_write(self.model_cls)) - @cached_property + @property def connection(self): db = self.database @@ -31,9 +31,12 @@ def connection(self): class OperationFeature(ConnectionMixin, BaseOperationFeature): - def execute(self, sql, autocommit=True): + def execute(self, sql, autocommit=True, connection=None): + if connection is None: + connection = self.connection + if not autocommit: - return self.connection.execute(sql) + return connection.execute(sql) try: autocommit = transaction.atomic # Django >= 1.6 @@ -45,22 +48,23 @@ def execute(self, sql, autocommit=True): autocommit = transaction.commit_on_success # Django <= 1.5 with autocommit(using=self.database): - return self.connection.execute(sql) + return connection.execute(sql) def select_one(self, sql): - self.execute(sql) - result = self.connection.fetchone() - return result[0] if result is not None else result + with closing(self.connection) as connection: + self.execute(sql, connection=connection) + result = connection.fetchone() + return result[0] if result is not None else result def select_all(self, sql, as_dict=False): - self.execute(sql) - - if as_dict: - result = [dict(zip([c[0] for c in self.connection.description], row)) for row in self.connection.fetchall()] - else: - result = self.connection.fetchall() + with closing(self.connection) as connection: + self.execute(sql, connection=connection) + if as_dict: + result = [dict(zip([c[0] for c in connection.description], row)) for row in connection.fetchall()] + else: + result = connection.fetchall() - return result + return result class PartitionFeature(ConnectionMixin, BasePartitionFeature):