Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pyperry.egg-info/
html/
*.pkl
*.log
virtualenv/
22 changes: 15 additions & 7 deletions pyperry/adapter/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class RestfulHttpAdapter(AbstractAdapter):
'foo'})} with C{params_wrapper='mod'}, the data encoded for the HTTP
request will include C{mod[id]=4&mod[name]=foo}

- B{default_params}: a python dict of additional paramters to include
alongside the models attributes in any write or delete request.
- B{default_params}: a python dict of additional parameters to build
into the query string.

These parameters will not be wrapped in the C{params_wrapper} if that
option is also present. One thing C{default_params} are useful for
Expand All @@ -69,7 +69,6 @@ def read(self, **kwargs):
if query_string is not None:
url += query_string


http_response, body = self.http_request('GET', url, {}, **kwargs)
response = self.response(http_response, body)
records = response.parsed()
Expand All @@ -92,6 +91,15 @@ def delete(self, **kwargs):
def persistence_request(self, http_method, **kwargs):
model = kwargs['model']
url = self.url_for(http_method, model)

if 'default_params' in self.config.keys():
query = self.restful_params(self.config['default_params'])
query_string = None
if len(query) > 0:
query_string = '?' + urllib.urlencode(query)
if query_string is not None:
url += query_string

params = self.restful_params(self.params_for(model))
http_response, body = self.http_request(http_method, url, params)
return self.response(http_response, body)
Expand Down Expand Up @@ -153,9 +161,6 @@ def params_for(self, model):
"""Builds and encodes a parameters dict for the request"""
params = {}

if 'default_params' in self.config.keys():
params.update(self.config['default_params'])

if 'params_wrapper' in self.config.keys():
params.update({self.config['params_wrapper']: model.fields})
else:
Expand All @@ -170,7 +175,10 @@ def query_string_for(self, relation):
and returns None if there are no parameters.

"""
query = relation.query()
query = {}
if 'default_params' in self.config.keys():
query.update(self.config['default_params'])
query.update(relation.query())
mods = relation.modifiers_value()
if 'query' in mods:
query.update(mods['query'])
Expand Down
14 changes: 8 additions & 6 deletions pyperry/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def source_klass(self, obj=None):
poly_type = None
if (self.options.has_key('polymorphic') and
self.options['polymorphic'] and obj):
if type(obj.__class__) in [pyperry.base.Base, pyperry.base.BaseMeta]:
if isinstance(obj, pyperry.base.Base):
poly_type = getattr(obj, '%s_type' % self.id)
else:
poly_type = obj
Expand Down Expand Up @@ -198,14 +198,16 @@ def source_klass(self, obj=None):
return self._get_resolved_class(type_string)

def _get_resolved_class(self, string):
class_name = self.target_klass.resolve_name(string)
if not class_name:
class_names = self.target_klass.resolve_name(string)
# remove duplicate entries created by module reloading
unique_names = list(set(class_names))
if not unique_names:
raise errors.ModelNotDefined, 'Model %s is not defined.' % (string)
elif len(class_name) > 1:
elif len(unique_names) > 1:
raise errors.AmbiguousClassName, ('Class name %s is'
' ambiguous. Use the namespace option to get your'
' specific class. Got classes %s' % (string, str(class_name)))
return class_name[0]
' specific class. Got classes %s' % (string, str(unique_names)))
return unique_names[0]

def _base_scope(self, obj):
return self.source_klass(obj).scoped().apply_finder_options(
Expand Down
102 changes: 67 additions & 35 deletions tests/restful_http_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,38 +142,6 @@ def test_with_wrapper(self):
params = adapter.params_for(self.model)
self.assertEqual(params, {'widget': self.model.fields})

def test_with_default_params(self):
"""should include the default_options with the attribuets"""
self.config['default_params'] = {'foo': 'bar'}
expected = copy(self.model.fields)
expected.update({'foo':'bar'})

adapter = RestfulHttpAdapter(self.config)
params = adapter.params_for(self.model)
self.assertEqual(params, expected)

def test_with_default_params_and_params_wrapper(self):
"""should include the attributes inside the wrapper and the default
params outside the wrapper"""
self.config['default_params'] = {'foo': 'bar', 'widget': 5}
self.config['params_wrapper'] = 'widget'
expected = copy(self.config['default_params'])
expected.update(copy({'widget':self.model.fields}))

adapter = RestfulHttpAdapter(self.config)
params = adapter.params_for(self.model)
self.assertEqual(params, expected)

def test_dont_modify_default_params(self):
"""should not modify the default_params when building params"""
self.config['default_params'] = {'foo':'bar'}
expected = copy(self.config['default_params'])

adapter = RestfulHttpAdapter(self.config)
params = adapter.params_for(self.model)
self.assertEqual(adapter.config['default_params'], expected)


class ReadTestCase(HttpAdapterTestCase):

def setUp(self):
Expand Down Expand Up @@ -237,6 +205,34 @@ def test_modifiers_and_relation(self):

self.assertEqual(query, expected)

def test_default_params_in_query(self):
"""should include the default_params in the query string"""
r = pyperry.Base.scoped().where({'id': 3}).limit(1)
self.config['default_params'] = {'foo': 'bar'}
adapter = RestfulHttpAdapter(self.config)
adapter.read(relation=r)

expected = 'where[][id]=3&limit=1&foo=bar'
expected = expected.replace('[', '%5B').replace(']', '%5D')
expected = expected.split('&')
expected.sort()

last_request = http_server.last_request()
query = last_request['path'].split('?')[1]
query = query.split('&')
query.sort()

self.assertEqual(query, expected)

def test_dont_modify_default_params(self):
"""should not modify the default_params when building a query string"""
r = pyperry.Base.scoped().where({'id': 3}).limit(1)
self.config['default_params'] = {'foo': 'bar'}
expected = copy(self.config['default_params'])
adapter = RestfulHttpAdapter(self.config)
adapter.read(relation=r)

self.assertEqual(adapter.config['default_params'], expected)

def test_records(self):
"""should return a list of records retrieved from the response"""
Expand Down Expand Up @@ -265,6 +261,10 @@ class PersistenceTestCase(HttpAdapterTestCase):
def respond_with_success(self, **kwargs):
http_server.set_response(**kwargs)

def adapter_method(self, **kwargs):
method = getattr(self.adapter, self.adapter_method_name)
return method(**kwargs)

def respond_with_failure(self, **kwargs):
error_kwargs = {'status': 500, 'body': 'ERROR'}
error_kwargs.update(kwargs)
Expand Down Expand Up @@ -323,6 +323,38 @@ def test_request(self):
self.assertEqual(last_request['headers']['content-type'],
'application/x-www-form-urlencoded')

def test_default_params_in_query(self):
"""should include the default_params in the query string"""
if type(self) is PersistenceTestCase: return
self.respond_with_success()
self.config['default_params'] = {'foo': 'bar'}
self.adapter = RestfulHttpAdapter(self.config)

response = self.adapter_method(model=self.model)

expected = 'foo=bar'
expected = expected.split('&')
expected.sort()

last_request = http_server.last_request()
query = last_request['path'].split('?')[1]
query = query.split('&')
query.sort()

self.assertEqual(query, expected)

def test_dont_modify_default_params(self):
"""should not modify the default_params when building a query string"""
if type(self) is PersistenceTestCase: return
self.respond_with_success()
self.config['default_params'] = {'foo': 'bar'}
expected = copy(self.config['default_params'])
self.adapter = RestfulHttpAdapter(self.config)

response = self.adapter_method(model=self.model)

self.assertEqual(self.adapter.config['default_params'], expected)


class CreateTestCase(PersistenceTestCase):
"""Run tests from PersistenceTestCase configured for creating a record"""
Expand All @@ -331,7 +363,7 @@ def setUp(self):
super(CreateTestCase, self).setUp()
self.model.new_record = True
self.http_method = 'POST'
self.adapter_method = self.adapter.write
self.adapter_method_name = 'write'
print "\n\tCreateTestCase" # Will display if test fails, so we know
# which test case the fail was from.

Expand All @@ -343,7 +375,7 @@ def setUp(self):
super(UpdateTestCase, self).setUp()
self.model.new_record = False
self.http_method = 'PUT'
self.adapter_method = self.adapter.write
self.adapter_method_name = 'write'
print "\n\tUpdateTestCase"


Expand All @@ -354,7 +386,7 @@ def setUp(self):
super(DeleteTestCase, self).setUp()
self.model.new_record = False
self.http_method = 'DELETE'
self.adapter_method = self.adapter.delete
self.adapter_method_name = 'delete'
print "\n\tDeleteTestCase"


Expand Down