diff --git a/.gitignore b/.gitignore index b8efaf8..f463706 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ pyperry.egg-info/ html/ *.pkl *.log +virtualenv/ diff --git a/pyperry/adapter/http.py b/pyperry/adapter/http.py index dd164e2..c25c85d 100644 --- a/pyperry/adapter/http.py +++ b/pyperry/adapter/http.py @@ -42,8 +42,8 @@ class RestfulHttpAdapter(AbstractAdapter): 'foo'})} with C{params_wrapper='mod'}, the data encoded for the HTTP request will include C{mod[id]=4&mod[name]=foo} - - B{default_params}: a python dict of additional paramters to include - alongside the models attributes in any write or delete request. + - B{default_params}: a python dict of additional parameters to build + into the query string. These parameters will not be wrapped in the C{params_wrapper} if that option is also present. One thing C{default_params} are useful for @@ -69,7 +69,6 @@ def read(self, **kwargs): if query_string is not None: url += query_string - http_response, body = self.http_request('GET', url, {}, **kwargs) response = self.response(http_response, body) records = response.parsed() @@ -92,6 +91,15 @@ def delete(self, **kwargs): def persistence_request(self, http_method, **kwargs): model = kwargs['model'] url = self.url_for(http_method, model) + + if 'default_params' in self.config.keys(): + query = self.restful_params(self.config['default_params']) + query_string = None + if len(query) > 0: + query_string = '?' + urllib.urlencode(query) + if query_string is not None: + url += query_string + params = self.restful_params(self.params_for(model)) http_response, body = self.http_request(http_method, url, params) return self.response(http_response, body) @@ -153,9 +161,6 @@ def params_for(self, model): """Builds and encodes a parameters dict for the request""" params = {} - if 'default_params' in self.config.keys(): - params.update(self.config['default_params']) - if 'params_wrapper' in self.config.keys(): params.update({self.config['params_wrapper']: model.fields}) else: @@ -170,7 +175,10 @@ def query_string_for(self, relation): and returns None if there are no parameters. """ - query = relation.query() + query = {} + if 'default_params' in self.config.keys(): + query.update(self.config['default_params']) + query.update(relation.query()) mods = relation.modifiers_value() if 'query' in mods: query.update(mods['query']) diff --git a/pyperry/association.py b/pyperry/association.py index bfda81d..900000d 100644 --- a/pyperry/association.py +++ b/pyperry/association.py @@ -163,7 +163,7 @@ def source_klass(self, obj=None): poly_type = None if (self.options.has_key('polymorphic') and self.options['polymorphic'] and obj): - if type(obj.__class__) in [pyperry.base.Base, pyperry.base.BaseMeta]: + if isinstance(obj, pyperry.base.Base): poly_type = getattr(obj, '%s_type' % self.id) else: poly_type = obj @@ -198,14 +198,16 @@ def source_klass(self, obj=None): return self._get_resolved_class(type_string) def _get_resolved_class(self, string): - class_name = self.target_klass.resolve_name(string) - if not class_name: + class_names = self.target_klass.resolve_name(string) + # remove duplicate entries created by module reloading + unique_names = list(set(class_names)) + if not unique_names: raise errors.ModelNotDefined, 'Model %s is not defined.' % (string) - elif len(class_name) > 1: + elif len(unique_names) > 1: raise errors.AmbiguousClassName, ('Class name %s is' ' ambiguous. Use the namespace option to get your' - ' specific class. Got classes %s' % (string, str(class_name))) - return class_name[0] + ' specific class. Got classes %s' % (string, str(unique_names))) + return unique_names[0] def _base_scope(self, obj): return self.source_klass(obj).scoped().apply_finder_options( diff --git a/tests/restful_http_adapter_test.py b/tests/restful_http_adapter_test.py index 119e18d..5579236 100644 --- a/tests/restful_http_adapter_test.py +++ b/tests/restful_http_adapter_test.py @@ -142,38 +142,6 @@ def test_with_wrapper(self): params = adapter.params_for(self.model) self.assertEqual(params, {'widget': self.model.fields}) - def test_with_default_params(self): - """should include the default_options with the attribuets""" - self.config['default_params'] = {'foo': 'bar'} - expected = copy(self.model.fields) - expected.update({'foo':'bar'}) - - adapter = RestfulHttpAdapter(self.config) - params = adapter.params_for(self.model) - self.assertEqual(params, expected) - - def test_with_default_params_and_params_wrapper(self): - """should include the attributes inside the wrapper and the default - params outside the wrapper""" - self.config['default_params'] = {'foo': 'bar', 'widget': 5} - self.config['params_wrapper'] = 'widget' - expected = copy(self.config['default_params']) - expected.update(copy({'widget':self.model.fields})) - - adapter = RestfulHttpAdapter(self.config) - params = adapter.params_for(self.model) - self.assertEqual(params, expected) - - def test_dont_modify_default_params(self): - """should not modify the default_params when building params""" - self.config['default_params'] = {'foo':'bar'} - expected = copy(self.config['default_params']) - - adapter = RestfulHttpAdapter(self.config) - params = adapter.params_for(self.model) - self.assertEqual(adapter.config['default_params'], expected) - - class ReadTestCase(HttpAdapterTestCase): def setUp(self): @@ -237,6 +205,34 @@ def test_modifiers_and_relation(self): self.assertEqual(query, expected) + def test_default_params_in_query(self): + """should include the default_params in the query string""" + r = pyperry.Base.scoped().where({'id': 3}).limit(1) + self.config['default_params'] = {'foo': 'bar'} + adapter = RestfulHttpAdapter(self.config) + adapter.read(relation=r) + + expected = 'where[][id]=3&limit=1&foo=bar' + expected = expected.replace('[', '%5B').replace(']', '%5D') + expected = expected.split('&') + expected.sort() + + last_request = http_server.last_request() + query = last_request['path'].split('?')[1] + query = query.split('&') + query.sort() + + self.assertEqual(query, expected) + + def test_dont_modify_default_params(self): + """should not modify the default_params when building a query string""" + r = pyperry.Base.scoped().where({'id': 3}).limit(1) + self.config['default_params'] = {'foo': 'bar'} + expected = copy(self.config['default_params']) + adapter = RestfulHttpAdapter(self.config) + adapter.read(relation=r) + + self.assertEqual(adapter.config['default_params'], expected) def test_records(self): """should return a list of records retrieved from the response""" @@ -265,6 +261,10 @@ class PersistenceTestCase(HttpAdapterTestCase): def respond_with_success(self, **kwargs): http_server.set_response(**kwargs) + def adapter_method(self, **kwargs): + method = getattr(self.adapter, self.adapter_method_name) + return method(**kwargs) + def respond_with_failure(self, **kwargs): error_kwargs = {'status': 500, 'body': 'ERROR'} error_kwargs.update(kwargs) @@ -323,6 +323,38 @@ def test_request(self): self.assertEqual(last_request['headers']['content-type'], 'application/x-www-form-urlencoded') + def test_default_params_in_query(self): + """should include the default_params in the query string""" + if type(self) is PersistenceTestCase: return + self.respond_with_success() + self.config['default_params'] = {'foo': 'bar'} + self.adapter = RestfulHttpAdapter(self.config) + + response = self.adapter_method(model=self.model) + + expected = 'foo=bar' + expected = expected.split('&') + expected.sort() + + last_request = http_server.last_request() + query = last_request['path'].split('?')[1] + query = query.split('&') + query.sort() + + self.assertEqual(query, expected) + + def test_dont_modify_default_params(self): + """should not modify the default_params when building a query string""" + if type(self) is PersistenceTestCase: return + self.respond_with_success() + self.config['default_params'] = {'foo': 'bar'} + expected = copy(self.config['default_params']) + self.adapter = RestfulHttpAdapter(self.config) + + response = self.adapter_method(model=self.model) + + self.assertEqual(self.adapter.config['default_params'], expected) + class CreateTestCase(PersistenceTestCase): """Run tests from PersistenceTestCase configured for creating a record""" @@ -331,7 +363,7 @@ def setUp(self): super(CreateTestCase, self).setUp() self.model.new_record = True self.http_method = 'POST' - self.adapter_method = self.adapter.write + self.adapter_method_name = 'write' print "\n\tCreateTestCase" # Will display if test fails, so we know # which test case the fail was from. @@ -343,7 +375,7 @@ def setUp(self): super(UpdateTestCase, self).setUp() self.model.new_record = False self.http_method = 'PUT' - self.adapter_method = self.adapter.write + self.adapter_method_name = 'write' print "\n\tUpdateTestCase" @@ -354,7 +386,7 @@ def setUp(self): super(DeleteTestCase, self).setUp() self.model.new_record = False self.http_method = 'DELETE' - self.adapter_method = self.adapter.delete + self.adapter_method_name = 'delete' print "\n\tDeleteTestCase"