diff --git a/.gitignore b/.gitignore index b2cb5de..3014d47 100644 --- a/.gitignore +++ b/.gitignore @@ -85,6 +85,8 @@ celerybeat-schedule .env .venv env/ +env2/ +env3/ venv/ ENV/ env.bak/ diff --git a/pinject/bindings.py b/pinject/bindings.py index f0949f3..3c5dcb3 100644 --- a/pinject/bindings.py +++ b/pinject/bindings.py @@ -14,12 +14,9 @@ """ -import inspect import re +import inspect import threading -import types - -from .third_party import decorator from . import binding_keys from . import decorators @@ -181,8 +178,7 @@ def get_provider_bindings( get_arg_names_from_provider_fn_name=( providing.default_get_arg_names_from_provider_fn_name)): provider_bindings = [] - fns = inspect.getmembers(binding_spec, - lambda x: type(x) == types.MethodType) + fns = inspect.getmembers(binding_spec, lambda x: inspect.ismethod(x)) for _, fn in fns: default_arg_names = get_arg_names_from_provider_fn_name(fn.__name__) fn_bindings = get_provider_fn_bindings(fn, default_arg_names) diff --git a/pinject/decorators.py b/pinject/decorators.py index 9dcb586..ae7882a 100644 --- a/pinject/decorators.py +++ b/pinject/decorators.py @@ -14,18 +14,14 @@ """ -import collections -import inspect - -from .third_party import decorator +import decorator from . import arg_binding_keys -from . import binding_keys +from . import support from . import errors from . import locations from . import scoping - _ARG_BINDING_KEYS_ATTR = '_pinject_arg_binding_keys' _IS_WRAPPER_ATTR = '_pinject_is_wrapper' _NON_INJECTABLE_ARG_NAMES_ATTR = '_pinject_non_injectables' @@ -83,13 +79,12 @@ def inject(arg_names=None, all_except=None): back_frame_loc = locations.get_back_frame_loc() if arg_names is not None and all_except is not None: raise errors.TooManyArgsToInjectDecoratorError(back_frame_loc) - for arg, arg_value in [('arg_names', arg_names), - ('all_except', all_except)]: + for arg, arg_value in [('arg_names', arg_names), ('all_except', all_except)]: if arg_value is not None: if not arg_value: raise errors.EmptySequenceArgError(back_frame_loc, arg) - if (not isinstance(arg_value, collections.Sequence) or - isinstance(arg_value, basestring)): + if (not support.is_sequence(arg_value) or + support.is_string(arg_value)): raise errors.WrongArgTypeError( arg, 'sequence (of arg names)', type(arg_value).__name__) if arg_names is None and all_except is None: @@ -207,6 +202,7 @@ def _get_pinject_decorated_fn(fn): else: def _pinject_decorated_fn(fn_to_wrap, *pargs, **kwargs): return fn_to_wrap(*pargs, **kwargs) + pinject_decorated_fn = decorator.decorator(_pinject_decorated_fn, fn) # TODO(kurts): split this so that __init__() decorators don't get # the provider attribute. @@ -225,28 +221,26 @@ def _get_pinject_wrapper( def get_pinject_decorated_fn_with_additions(fn): pinject_decorated_fn = _get_pinject_decorated_fn(fn) orig_arg_names, unused_varargs, unused_keywords, unused_defaults = ( - inspect.getargspec(getattr(pinject_decorated_fn, _ORIG_FN_ATTR))) + support.get_method_args(getattr(pinject_decorated_fn, _ORIG_FN_ATTR))) if arg_binding_key is not None: - if not arg_binding_key.can_apply_to_one_of_arg_names( - orig_arg_names): - raise errors.NoSuchArgToInjectError( - decorator_loc, arg_binding_key, fn) + if not arg_binding_key.can_apply_to_one_of_arg_names(orig_arg_names): + raise errors.NoSuchArgToInjectError(decorator_loc, arg_binding_key, fn) if arg_binding_key.conflicts_with_any_arg_binding_key( - getattr(pinject_decorated_fn, _ARG_BINDING_KEYS_ATTR)): + getattr(pinject_decorated_fn, _ARG_BINDING_KEYS_ATTR)): raise errors.MultipleAnnotationsForSameArgError( arg_binding_key, decorator_loc) getattr(pinject_decorated_fn, _ARG_BINDING_KEYS_ATTR).append( arg_binding_key) if (provider_arg_name is not None or provider_annotated_with is not None or - provider_in_scope_id is not None): + provider_in_scope_id is not None): provider_decorations = getattr( pinject_decorated_fn, _PROVIDER_DECORATIONS_ATTR) provider_decorations.append(ProviderDecoration( provider_arg_name, provider_annotated_with, provider_in_scope_id)) if (inject_arg_names is not None or - inject_all_except_arg_names is not None): + inject_all_except_arg_names is not None): if hasattr(pinject_decorated_fn, _NON_INJECTABLE_ARG_NAMES_ATTR): raise errors.DuplicateDecoratorError('inject', decorator_loc) non_injectable_arg_names = [] @@ -265,6 +259,7 @@ def get_pinject_decorated_fn_with_additions(fn): if len(non_injectable_arg_names) == len(orig_arg_names): raise errors.NoRemainingArgsToInjectError(decorator_loc) return pinject_decorated_fn + return get_pinject_decorated_fn_with_additions @@ -285,8 +280,7 @@ def get_injectable_arg_binding_keys(fn, direct_pargs, direct_kwargs): existing_arg_binding_keys = [] orig_fn = fn - arg_names, unused_varargs, unused_keywords, defaults = ( - inspect.getargspec(orig_fn)) + arg_names, unused_varargs, unused_keywords, defaults = support.get_method_args(orig_fn) num_args_with_defaults = len(defaults) if defaults is not None else 0 if num_args_with_defaults: arg_names = arg_names[:-num_args_with_defaults] diff --git a/pinject/errors.py b/pinject/errors.py index 6e3c21f..af5d21e 100644 --- a/pinject/errors.py +++ b/pinject/errors.py @@ -14,7 +14,7 @@ """ -import locations +from . import locations class Error(Exception): diff --git a/pinject/finding.py b/pinject/finding.py index ea8679a..2f07603 100644 --- a/pinject/finding.py +++ b/pinject/finding.py @@ -35,11 +35,10 @@ def find_classes(modules, classes): def _get_explicit_or_default_modules(modules): if modules is ALL_IMPORTED_MODULES: - return sys.modules.values() + return list(sys.modules.values()) elif modules is None: return [] - else: - return modules + return modules def _find_classes_in_module(module): diff --git a/pinject/initializers.py b/pinject/initializers.py index dba1be8..812fbd4 100644 --- a/pinject/initializers.py +++ b/pinject/initializers.py @@ -14,11 +14,10 @@ """ -import inspect - -from .third_party import decorator +import decorator from . import errors +from . import support def copy_args_to_internal_fields(fn): @@ -42,14 +41,16 @@ def _copy_args_to_fields(fn, decorator_name, field_prefix): raise errors.DecoratorAppliedToNonInitError( decorator_name, fn) arg_names, varargs, unused_keywords, unused_defaults = ( - inspect.getargspec(fn)) + support.get_method_args(fn)) if varargs is not None: raise errors.PargsDisallowedWhenCopyingArgsError( decorator_name, fn, varargs) + def CopyThenCall(fn_to_wrap, self, *pargs, **kwargs): for index, parg in enumerate(pargs, start=1): setattr(self, field_prefix + arg_names[index], parg) - for kwarg, kwvalue in kwargs.iteritems(): + for kwarg, kwvalue in support.items(kwargs): setattr(self, field_prefix + kwarg, kwvalue) fn_to_wrap(self, *pargs, **kwargs) + return decorator.decorator(CopyThenCall, fn) diff --git a/pinject/locations.py b/pinject/locations.py index 0dbcf9a..e42bb49 100644 --- a/pinject/locations.py +++ b/pinject/locations.py @@ -16,6 +16,8 @@ import inspect +LOCALS_TOKEN = '' + def get_loc(thing): try: @@ -27,12 +29,8 @@ def get_loc(thing): def get_name_and_loc(thing): try: - if hasattr(thing, 'im_class'): - class_name = '{0}.{1}'.format( - thing.im_class.__name__, thing.__name__) - else: - class_name = '{0}.{1}'.format( - inspect.getmodule(thing).__name__, thing.__name__) + type_name = _get_type_name(thing) + class_name = '{0}.{1}'.format(type_name, thing.__name__) except (TypeError, IOError): class_name = '{0}.{1}'.format( inspect.getmodule(thing).__name__, thing.__name__) @@ -47,3 +45,68 @@ def get_back_frame_loc(): back_frame = inspect.currentframe().f_back.f_back return '{0}:{1}'.format(back_frame.f_code.co_filename, back_frame.f_lineno) + + +def _get_type_name(target_thing): + """ + Functions, bound methods and unbound methods change significantly in Python 3. + + For instance: + + class SomeObject(object): + def method(): + pass + + In Python 2: + - Unbound method inspect.ismethod(SomeObject.method) returns True + - Unbound inspect.isfunction(SomeObject.method) returns False + - Unbound hasattr(SomeObject.method, 'im_class') returns True + - Bound method inspect.ismethod(SomeObject().method) returns True + - Bound method inspect.isfunction(SomeObject().method) returns False + - Bound hasattr(SomeObject().method, 'im_class') returns True + + In Python 3: + - Unbound method inspect.ismethod(SomeObject.method) returns False + - Unbound inspect.isfunction(SomeObject.method) returns True + - Unbound hasattr(SomeObject.method, 'im_class') returns False + - Bound method inspect.ismethod(SomeObject().method) returns True + - Bound method inspect.isfunction(SomeObject().method) returns False + - Bound hasattr(SomeObject().method, 'im_class') returns False + + This method tries to consolidate the approach for retrieving the + enclosing type of a bound/unbound method and functions. + """ + thing = target_thing + if hasattr(thing, 'im_class'): + # only works in Python 2 + return thing.im_class.__name__ + if inspect.ismethod(thing): + for cls in inspect.getmro(thing.__self__.__class__): + if cls.__dict__.get(thing.__name__) is thing: + return cls.__name__ + thing = thing.__func__ + if inspect.isfunction(thing) and hasattr(thing, '__qualname__'): + qualifier = thing.__qualname__ + if LOCALS_TOKEN in qualifier: + return _get_local_type_name(thing) + return _get_external_type_name(thing) + return inspect.getmodule(target_thing).__name__ + + +def _get_local_type_name(thing): + qualifier = thing.__qualname__ + parts = qualifier.split(LOCALS_TOKEN, 1) + type_name = parts[1].split('.')[1] + if thing.__name__ == type_name: + return inspect.getmodule(thing).__name__ + return type_name + + +def _get_external_type_name(thing): + qualifier = thing.__qualname__ + name = qualifier.rsplit('.', 1)[0] + if hasattr(inspect.getmodule(thing), name): + cls = getattr(inspect.getmodule(thing), name) + if isinstance(cls, type): + return cls.__name__ + return inspect.getmodule(thing).__name__ diff --git a/pinject/object_graph.py b/pinject/object_graph.py index 6753d83..9d73a27 100644 --- a/pinject/object_graph.py +++ b/pinject/object_graph.py @@ -14,12 +14,6 @@ """ -import collections -import functools -import inspect -import types - -from . import arg_binding_keys from . import bindings from . import decorators from . import errors @@ -30,6 +24,7 @@ from . import providing from . import required_bindings as required_bindings_lib from . import scoping +from . import support def new_object_graph( @@ -81,21 +76,21 @@ def new_object_graph( """ try: if modules is not None and modules is not finding.ALL_IMPORTED_MODULES: - _verify_types(modules, types.ModuleType, 'modules') + support.verify_module_types(modules, 'modules') if classes is not None: - _verify_types(classes, types.TypeType, 'classes') + support.verify_class_types(classes, 'classes') if binding_specs is not None: - _verify_subclasses( + support.verify_subclasses( binding_specs, bindings.BindingSpec, 'binding_specs') if get_arg_names_from_class_name is not None: - _verify_callable(get_arg_names_from_class_name, - 'get_arg_names_from_class_name') + support.verify_callable(get_arg_names_from_class_name, + 'get_arg_names_from_class_name') if get_arg_names_from_provider_fn_name is not None: - _verify_callable(get_arg_names_from_provider_fn_name, - 'get_arg_names_from_provider_fn_name') + support.verify_callable(get_arg_names_from_provider_fn_name, + 'get_arg_names_from_provider_fn_name') if is_scope_usable_from_scope is not None: - _verify_callable(is_scope_usable_from_scope, - 'is_scope_usable_from_scope') + support.verify_callable(is_scope_usable_from_scope, + 'is_scope_usable_from_scope') injection_context_factory = injection_contexts.InjectionContextFactory( is_scope_usable_from_scope) id_to_scope = scoping.get_id_to_scope_with_defaults(id_to_scope) @@ -169,46 +164,10 @@ def new_object_graph( use_short_stack_traces) -def _verify_type(elt, required_type, arg_name): - if type(elt) != required_type: - raise errors.WrongArgTypeError( - arg_name, required_type.__name__, type(elt).__name__) - - -def _verify_types(seq, required_type, arg_name): - if not isinstance(seq, collections.Sequence): - raise errors.WrongArgTypeError( - arg_name, 'sequence (of {0})'.format(required_type.__name__), - type(seq).__name__) - for idx, elt in enumerate(seq): - if type(elt) != required_type: - raise errors.WrongArgElementTypeError( - arg_name, idx, required_type.__name__, type(elt).__name__) - - -def _verify_subclasses(seq, required_superclass, arg_name): - if not isinstance(seq, collections.Sequence): - raise errors.WrongArgTypeError( - arg_name, - 'sequence (of subclasses of {0})'.format( - required_superclass.__name__), - type(seq).__name__) - for idx, elt in enumerate(seq): - if not isinstance(elt, required_superclass): - raise errors.WrongArgElementTypeError( - arg_name, idx, - 'subclass of {0}'.format(required_superclass.__name__), - type(elt).__name__) - - -def _verify_callable(fn, arg_name): - if not callable(fn): - raise errors.WrongArgTypeError(arg_name, 'callable', type(fn).__name__) - - def _pare_to_present_args(kwargs, fn): - arg_names, _, _, _ = inspect.getargspec(fn) - return {arg: value for arg, value in kwargs.iteritems() if arg in arg_names} + arg_names, _, _, _ = support.get_method_args(fn) + return {arg: value + for arg, value in support.items(kwargs) if arg in arg_names} class ObjectGraph(object): @@ -231,7 +190,7 @@ def provide(self, cls): Raises: Error: an instance of cls is not providable """ - _verify_type(cls, types.TypeType, 'cls') + support.verify_class_type(cls, 'cls') if not self._is_injectable_fn(cls): provide_loc = locations.get_back_frame_loc() raise errors.NonExplicitlyBoundClassError(provide_loc, cls) diff --git a/pinject/object_providers.py b/pinject/object_providers.py index e66619d..567b210 100644 --- a/pinject/object_providers.py +++ b/pinject/object_providers.py @@ -14,8 +14,7 @@ """ -import types - +from . import support from . import arg_binding_keys from . import decorators from . import errors @@ -61,7 +60,7 @@ def Provide(*pargs, **kwargs): def provide_class(self, cls, injection_context, direct_init_pargs, direct_init_kwargs): - if type(cls.__init__) is types.MethodType: + if support.is_constructor_defined(cls): init_pargs, init_kwargs = self.get_injection_pargs_kwargs( cls.__init__, injection_context, direct_init_pargs, direct_init_kwargs) diff --git a/pinject/support.py b/pinject/support.py new file mode 100644 index 0000000..973e518 --- /dev/null +++ b/pinject/support.py @@ -0,0 +1,105 @@ +"""Copyright 2013 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +import six +import inspect + +from . import errors + +# To be removed once this fix is included in six +# https://github.com/benjaminp/six/issues/155 +try: + import collections.abc as collections_abc +except ImportError: # python <3.3 + import collections as collections_abc + + +def items(dict_instance): + return six.iteritems(dict_instance) + + +def is_sequence(arg_value): + return isinstance(arg_value, collections_abc.Sequence) + + +def is_string(arg_value): + return isinstance(arg_value, six.string_types) + + +def is_constructor_defined(cls): + if six.PY3: + return inspect.isfunction(cls.__init__) + return inspect.ismethod(cls.__init__) + + +def get_method_args(fn): + if six.PY3: + spec = inspect.getfullargspec(fn) + return spec.args, spec.varargs, spec.varkw, spec.defaults + arg_names, varargs, keywords, defaults = inspect.getargspec(fn) + return arg_names, varargs, keywords, defaults + + +def verify_callable(fn, arg_name): + if not callable(fn): + raise errors.WrongArgTypeError(arg_name, 'callable', type(fn).__name__) + + +def verify_subclasses(seq, required_superclass, arg_name): + if not isinstance(seq, collections_abc.Sequence): + raise errors.WrongArgTypeError( + arg_name, + 'sequence (of subclasses of {0})'.format( + required_superclass.__name__), + type(seq).__name__) + for idx, elt in enumerate(seq): + if not isinstance(elt, required_superclass): + raise errors.WrongArgElementTypeError( + arg_name, idx, + 'subclass of {0}'.format(required_superclass.__name__), + type(elt).__name__) + + +def verify_module_types(modules, arg_name): + _verify_types(inspect.ismodule, modules, arg_name, 'module') + + +def verify_class_types(seq, arg_name): + _verify_types(inspect.isclass, seq, arg_name, 'class') + + +def verify_class_type(elt, arg_name): + _verify_type(inspect.isclass, elt, arg_name, 'class') + + +def _assert_sequence(seq, arg_name, type_name): + if not is_sequence(seq): + raise errors.WrongArgTypeError( + arg_name, 'sequence (of {0})'.format(type_name), type(seq).__name__) + + +def _verify_types(fn_checker, seq, arg_name, type_name): + _assert_sequence(seq, arg_name, type_name) + for idx, elt in enumerate(seq): + if not fn_checker(elt): + raise errors.WrongArgElementTypeError( + arg_name, idx, type_name, type(elt).__name__) + + +def _verify_type(fn_checker, elt, arg_name, type_name): + if not fn_checker(elt): + raise errors.WrongArgTypeError( + arg_name, type_name, type(elt).__name__) diff --git a/pinject/support_test.py b/pinject/support_test.py new file mode 100644 index 0000000..4d78099 --- /dev/null +++ b/pinject/support_test.py @@ -0,0 +1,142 @@ +"""Copyright 2013 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +import unittest +import types +import inspect + +from pinject import support +from pinject import bindings +from pinject import errors + + +class VerifyTypeTest(unittest.TestCase): + + def test_verifies_correct_type_ok(self): + support._verify_type(inspect.ismodule, types, 'unused', 'module') + + def test_raises_exception_if_incorrect_type(self): + self.assertRaises(errors.WrongArgTypeError, support._verify_type, + inspect.ismodule, 'not-a-module', + 'an-arg-name', 'module') + + +class VerifyTypesTest(unittest.TestCase): + + def test_verifies_empty_sequence_ok(self): + def fn_checker(elt): + return True + support._verify_types(fn_checker, [], 'unused', 'no_type') + + def test_verifies_correct_type_ok(self): + support._verify_types(inspect.ismodule, [types], 'unused', 'module') + + def test_raises_exception_if_not_sequence(self): + self.assertRaises(errors.WrongArgTypeError, support._verify_types, + inspect.ismodule, 42, 'an-arg-name', 'module') + + def test_raises_exception_if_element_is_incorrect_type(self): + self.assertRaises(errors.WrongArgElementTypeError, + support._verify_types, inspect.ismodule, + ['not-a-module'], 'an-arg-name', 'module') + + +class VerifySubclassesTest(unittest.TestCase): + + def test_verifies_empty_sequence_ok(self): + support.verify_subclasses([], bindings.BindingSpec, 'unused') + + def test_verifies_correct_type_ok(self): + class SomeBindingSpec(bindings.BindingSpec): + pass + support.verify_subclasses( + [SomeBindingSpec()], bindings.BindingSpec, 'unused') + + def test_raises_exception_if_not_sequence(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_subclasses, + 42, bindings.BindingSpec, 'an-arg-name') + + def test_raises_exception_if_element_is_not_subclass(self): + class NotBindingSpec(object): + pass + self.assertRaises( + errors.WrongArgElementTypeError, support.verify_subclasses, + [NotBindingSpec()], bindings.BindingSpec, 'an-arg-name') + + +class VerifyCallableTest(unittest.TestCase): + + def test_verifies_callable_ok(self): + support.verify_callable(lambda: None, 'unused') + + def test_raises_exception_if_not_callable(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_callable, 42, 'an-arg-name') + + +class VerifyModuleTypesTest(unittest.TestCase): + + def test_verifies_module_types_ok(self): + support.verify_module_types([types], 'unused') + + def test_raises_exception_if_not_module_types(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_module_types, 42, 'an-arg-name') + + +class VerifyClassTypesTest(unittest.TestCase): + + def test_verifies_module_types_ok(self): + class Foo(object): + pass + support.verify_class_types([Foo], 'unused') + + def test_raises_exception_if_not_class_types(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_class_types, 42, 'an-arg-name') + + +class IsSequenceTest(unittest.TestCase): + + def test_argument_identified_as_sequence_instance(self): + self.assertTrue(support.is_sequence(list())) + + def test_argument_identified_as_not_sequence_instance(self): + self.assertFalse(support.is_sequence(None)) + + +class IsStringTest(unittest.TestCase): + + def test_argument_identified_as_string_instance(self): + self.assertTrue(support.is_string('some string')) + + def test_argument_identified_as_not_string_instance(self): + self.assertFalse(support.is_string(None)) + + +class IsConstructorDefinedTest(unittest.TestCase): + + def test_constructor_present_detection(self): + class Foo(object): + def __init__(self): + pass + self.assertTrue(support.is_constructor_defined(Foo)) + + def test_constructor_not_present_detection(self): + class Foo(object): + pass + self.assertFalse(support.is_constructor_defined(Foo)) diff --git a/pinject/third_party/__init__.py b/pinject/third_party/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/pinject/third_party/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/pinject/third_party/decorator.py b/pinject/third_party/decorator.py deleted file mode 100644 index e003914..0000000 --- a/pinject/third_party/decorator.py +++ /dev/null @@ -1,251 +0,0 @@ -########################## LICENCE ############################### - -# Copyright (c) 2005-2012, Michele Simionato -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: - -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in bytecode form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in -# the documentation and/or other materials provided with the -# distribution. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS -# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. - -""" -Decorator module, see http://pypi.python.org/pypi/decorator -for the documentation. -""" - -__version__ = '3.4.0' - -__all__ = ["decorator", "FunctionMaker", "contextmanager"] - -import sys, re, inspect -if sys.version >= '3': - from inspect import getfullargspec - def get_init(cls): - return cls.__init__ -else: - class getfullargspec(object): - "A quick and dirty replacement for getfullargspec for Python 2.X" - def __init__(self, f): - self.args, self.varargs, self.varkw, self.defaults = \ - inspect.getargspec(f) - self.kwonlyargs = [] - self.kwonlydefaults = None - def __iter__(self): - yield self.args - yield self.varargs - yield self.varkw - yield self.defaults - def get_init(cls): - return cls.__init__.im_func - -DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') - -# basic functionality -class FunctionMaker(object): - """ - An object with the ability to create functions with a given signature. - It has attributes name, doc, module, signature, defaults, dict and - methods update and make. - """ - def __init__(self, func=None, name=None, signature=None, - defaults=None, doc=None, module=None, funcdict=None): - self.shortsignature = signature - if func: - # func can be a class or a callable, but not an instance method - self.name = func.__name__ - if self.name == '': # small hack for lambda functions - self.name = '_lambda_' - self.doc = func.__doc__ - self.module = func.__module__ - if inspect.isfunction(func): - argspec = getfullargspec(func) - self.annotations = getattr(func, '__annotations__', {}) - for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', - 'kwonlydefaults'): - setattr(self, a, getattr(argspec, a)) - for i, arg in enumerate(self.args): - setattr(self, 'arg%d' % i, arg) - if sys.version < '3': # easy way - self.shortsignature = self.signature = \ - inspect.formatargspec( - formatvalue=lambda val: "", *argspec)[1:-1] - else: # Python 3 way - allargs = list(self.args) - allshortargs = list(self.args) - if self.varargs: - allargs.append('*' + self.varargs) - allshortargs.append('*' + self.varargs) - elif self.kwonlyargs: - allargs.append('*') # single star syntax - for a in self.kwonlyargs: - allargs.append('%s=None' % a) - allshortargs.append('%s=%s' % (a, a)) - if self.varkw: - allargs.append('**' + self.varkw) - allshortargs.append('**' + self.varkw) - self.signature = ', '.join(allargs) - self.shortsignature = ', '.join(allshortargs) - self.dict = func.__dict__.copy() - # func=None happens when decorating a caller - if name: - self.name = name - if signature is not None: - self.signature = signature - if defaults: - self.defaults = defaults - if doc: - self.doc = doc - if module: - self.module = module - if funcdict: - self.dict = funcdict - # check existence required attributes - assert hasattr(self, 'name') - if not hasattr(self, 'signature'): - raise TypeError('You are decorating a non function: %s' % func) - - def update(self, func, **kw): - "Update the signature of func with the data in self" - func.__name__ = self.name - func.__doc__ = getattr(self, 'doc', None) - func.__dict__ = getattr(self, 'dict', {}) - func.func_defaults = getattr(self, 'defaults', ()) - func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None) - func.__annotations__ = getattr(self, 'annotations', None) - callermodule = sys._getframe(3).f_globals.get('__name__', '?') - func.__module__ = getattr(self, 'module', callermodule) - func.__dict__.update(kw) - - def make(self, src_templ, evaldict=None, addsource=False, **attrs): - "Make a new function from a given template and update the signature" - src = src_templ % vars(self) # expand name and signature - evaldict = evaldict or {} - mo = DEF.match(src) - if mo is None: - raise SyntaxError('not a valid function template\n%s' % src) - name = mo.group(1) # extract the function name - names = set([name] + [arg.strip(' *') for arg in - self.shortsignature.split(',')]) - for n in names: - if n in ('_func_', '_call_'): - raise NameError('%s is overridden in\n%s' % (n, src)) - if not src.endswith('\n'): # add a newline just for safety - src += '\n' # this is needed in old versions of Python - try: - code = compile(src, '', 'single') - # print >> sys.stderr, 'Compiling %s' % src - exec code in evaldict - except: - print >> sys.stderr, 'Error in generated code:' - print >> sys.stderr, src - raise - func = evaldict[name] - if addsource: - attrs['__source__'] = src - self.update(func, **attrs) - return func - - @classmethod - def create(cls, obj, body, evaldict, defaults=None, - doc=None, module=None, addsource=True, **attrs): - """ - Create a function from the strings name, signature and body. - evaldict is the evaluation dictionary. If addsource is true an attribute - __source__ is added to the result. The attributes attrs are added, - if any. - """ - if isinstance(obj, str): # "name(signature)" - name, rest = obj.strip().split('(', 1) - signature = rest[:-1] #strip a right parens - func = None - else: # a function - name = None - signature = None - func = obj - self = cls(func, name, signature, defaults, doc, module) - ibody = '\n'.join(' ' + line for line in body.splitlines()) - return self.make('def %(name)s(%(signature)s):\n' + ibody, - evaldict, addsource, **attrs) - -def decorator(caller, func=None): - """ - decorator(caller) converts a caller function into a decorator; - decorator(caller, func) decorates a function using a caller. - """ - if func is not None: # returns a decorated function - evaldict = func.func_globals.copy() - evaldict['_call_'] = caller - evaldict['_func_'] = func - return FunctionMaker.create( - func, "return _call_(_func_, %(shortsignature)s)", - evaldict, undecorated=func, __wrapped__=func) - else: # returns a decorator - if inspect.isclass(caller): - name = caller.__name__.lower() - callerfunc = get_init(caller) - doc = 'decorator(%s) converts functions/generators into ' \ - 'factories of %s objects' % (caller.__name__, caller.__name__) - fun = getfullargspec(callerfunc).args[1] # second arg - elif inspect.isfunction(caller): - name = '_lambda_' if caller.__name__ == '' \ - else caller.__name__ - callerfunc = caller - doc = caller.__doc__ - fun = getfullargspec(callerfunc).args[0] # first arg - else: # assume caller is an object with a __call__ method - name = caller.__class__.__name__.lower() - callerfunc = caller.__call__.im_func - doc = caller.__call__.__doc__ - fun = getfullargspec(callerfunc).args[1] # second arg - evaldict = callerfunc.func_globals.copy() - evaldict['_call_'] = caller - evaldict['decorator'] = decorator - return FunctionMaker.create( - '%s(%s)' % (name, fun), - 'return decorator(_call_, %s)' % fun, - evaldict, undecorated=caller, __wrapped__=caller, - doc=doc, module=caller.__module__) - -######################### contextmanager ######################## - -def __call__(self, func): - 'Context manager decorator' - return FunctionMaker.create( - func, "with _self_: return _func_(%(shortsignature)s)", - dict(_self_=self, _func_=func), __wrapped__=func) - -try: # Python >= 3.2 - - from contextlib import _GeneratorContextManager - ContextManager = type( - 'ContextManager', (_GeneratorContextManager,), dict(__call__=__call__)) - -except ImportError: # Python >= 2.5 - - from contextlib import GeneratorContextManager - def __init__(self, f, *a, **k): - return GeneratorContextManager.__init__(self, f(*a, **k)) - ContextManager = type( - 'ContextManager', (GeneratorContextManager,), - dict(__call__=__call__, __init__=__init__)) - -contextmanager = decorator(ContextManager) diff --git a/requirements_dev.txt b/requirements_dev.txt index d9e72ef..3a64df7 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,5 +1,6 @@ # dev deps -nose>=1.3.7 +pytest>=3.8.0 # prod deps six>=1.7.3 +decorator>=4.3.0 diff --git a/setup.py b/setup.py index c6cc2bc..3358fc4 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,7 @@ # limitations under the License. -from distutils.core import setup - +from setuptools import setup setup(name='pinject', version='0.11.0', @@ -27,4 +26,5 @@ license='Apache License 2.0', long_description=open('README.rst').read(), platforms='all', - packages=['pinject', 'pinject/third_party']) + packages=['pinject', 'pinject/third_party'], + install_requires=['six>=1.7.3', 'decorator>=4.3.0']) diff --git a/tests/arg_binding_keys_test.py b/tests/arg_binding_keys_test.py index b898f99..42c855f 100644 --- a/tests/arg_binding_keys_test.py +++ b/tests/arg_binding_keys_test.py @@ -16,7 +16,6 @@ import unittest -from pinject import annotations from pinject import arg_binding_keys from pinject import binding_keys from pinject import provider_indirections diff --git a/tests/decorators_test.py b/tests/decorators_test.py index dc9ac9e..1d3bdf8 100644 --- a/tests/decorators_test.py +++ b/tests/decorators_test.py @@ -14,7 +14,6 @@ """ -import inspect import unittest from pinject import arg_binding_keys @@ -22,8 +21,8 @@ from pinject import binding_keys from pinject import decorators from pinject import errors -from pinject import injection_contexts from pinject import scoping +from pinject import support # TODO(kurts): have only one FakeObjectProvider for tests. @@ -288,8 +287,7 @@ def test_can_introspect_wrapped_fn(self): @decorators.annotate_arg('foo', 'an-annotation') def some_function(foo, bar='BAR', *pargs, **kwargs): pass - arg_names, varargs, keywords, defaults = inspect.getargspec( - some_function) + arg_names, varargs, keywords, defaults = support.get_method_args(some_function) self.assertEqual(['foo', 'bar'], arg_names) self.assertEqual('pargs', varargs) self.assertEqual('kwargs', keywords) diff --git a/tests/initializers_test.py b/tests/initializers_test.py index 418b192..55c6e35 100644 --- a/tests/initializers_test.py +++ b/tests/initializers_test.py @@ -14,11 +14,11 @@ """ -import inspect import unittest from pinject import errors from pinject import initializers +from pinject import support class CopyArgsToInternalFieldsTest(unittest.TestCase): @@ -65,7 +65,7 @@ def __init__(self, foo): pass self.assertEqual('__init__', SomeClass.__init__.__name__) arg_names, unused_varargs, unused_keywords, unused_defaults = ( - inspect.getargspec(SomeClass.__init__)) + support.get_method_args(SomeClass.__init__)) self.assertEqual(['self', 'foo'], arg_names) def test_raises_exception_if_init_takes_pargs(self): diff --git a/tests/locations_test.py b/tests/locations_test.py index cbc2582..bd7c603 100644 --- a/tests/locations_test.py +++ b/tests/locations_test.py @@ -19,6 +19,11 @@ from pinject import locations +class ExternalObject(object): + def a_method(self): + pass + + class GetTypeLocTest(unittest.TestCase): def test_known(self): @@ -40,7 +45,12 @@ class OtherObject(object): self.assertIn('OtherObject', class_name_and_loc) self.assertIn('locations_test.py', class_name_and_loc) - def test_known_as_part_of_class(self): + def test_known_external(self): + class_name_and_loc = locations.get_name_and_loc(ExternalObject) + self.assertIn('ExternalObject', class_name_and_loc) + self.assertIn('locations_test.py', class_name_and_loc) + + def test_known_unbound_method_as_part_of_class(self): class OtherObject(object): def a_method(self): pass @@ -48,12 +58,56 @@ def a_method(self): self.assertIn('OtherObject.a_method', class_name_and_loc) self.assertIn('locations_test.py', class_name_and_loc) + def test_known_external_unbound_method_as_part_of_class(self): + class_name_and_loc = locations.get_name_and_loc(ExternalObject.a_method) + self.assertIn('ExternalObject.a_method', class_name_and_loc) + self.assertIn('locations_test.py', class_name_and_loc) + + def test_known_bound_method_as_part_of_class(self): + class OtherObject(object): + def a_method(self): + pass + class_name_and_loc = locations.get_name_and_loc(OtherObject().a_method) + self.assertIn('OtherObject.a_method', class_name_and_loc) + self.assertIn('locations_test.py', class_name_and_loc) + + def test_known_external_bound_method_as_part_of_class(self): + class_name_and_loc = locations.get_name_and_loc(ExternalObject().a_method) + self.assertIn('ExternalObject.a_method', class_name_and_loc) + self.assertIn('locations_test.py', class_name_and_loc) + def test_unknown(self): unknown_class = type('UnknownClass', (object,), {}) class_name_and_loc = locations.get_name_and_loc(unknown_class) self.assertEqual('tests.locations_test.UnknownClass', class_name_and_loc) +class GetTypeName(unittest.TestCase): + + def test_get_type_name_unbound_method(self): + class SomeObject(object): + def a_method(self): + pass + self.assertEqual('SomeObject', locations._get_type_name(SomeObject.a_method)) + + def test_get_type_name_external_unbound_method(self): + self.assertEqual('ExternalObject', locations._get_type_name(ExternalObject.a_method)) + + def test_get_type_name_bound_method(self): + class SomeObject(object): + def a_method(self): + pass + self.assertEqual('SomeObject', locations._get_type_name(SomeObject().a_method)) + + def test_get_type_name_external_bound_method(self): + self.assertEqual('ExternalObject', locations._get_type_name(ExternalObject().a_method)) + + def test_get_type_name_function(self): + def method(): + pass + self.assertEqual('tests.locations_test', locations._get_type_name(method)) + + class GetBackFrameLocTest(unittest.TestCase): def test_correct_file_and_line(self): diff --git a/tests/object_graph_test.py b/tests/object_graph_test.py index c435a88..ba7a52c 100644 --- a/tests/object_graph_test.py +++ b/tests/object_graph_test.py @@ -14,8 +14,6 @@ """ -import inspect -import types import unittest from pinject import bindings @@ -230,68 +228,6 @@ class _Foo(object): binding_specs=[SomeBindingSpec()]) -class VerifyTypeTest(unittest.TestCase): - - def test_verifies_correct_type_ok(self): - object_graph._verify_type(types, types.ModuleType, 'unused') - - def test_raises_exception_if_incorrect_type(self): - self.assertRaises(errors.WrongArgTypeError, object_graph._verify_type, - 'not-a-module', types.ModuleType, 'an-arg-name') - - -class VerifyTypesTest(unittest.TestCase): - - def test_verifies_empty_sequence_ok(self): - object_graph._verify_types([], types.ModuleType, 'unused') - - def test_verifies_correct_type_ok(self): - object_graph._verify_types([types], types.ModuleType, 'unused') - - def test_raises_exception_if_not_sequence(self): - self.assertRaises(errors.WrongArgTypeError, object_graph._verify_types, - 42, types.ModuleType, 'an-arg-name') - - def test_raises_exception_if_element_is_incorrect_type(self): - self.assertRaises(errors.WrongArgElementTypeError, - object_graph._verify_types, - ['not-a-module'], types.ModuleType, 'an-arg-name') - - -class VerifySubclassesTest(unittest.TestCase): - - def test_verifies_empty_sequence_ok(self): - object_graph._verify_subclasses([], bindings.BindingSpec, 'unused') - - def test_verifies_correct_type_ok(self): - class SomeBindingSpec(bindings.BindingSpec): - pass - object_graph._verify_subclasses( - [SomeBindingSpec()], bindings.BindingSpec, 'unused') - - def test_raises_exception_if_not_sequence(self): - self.assertRaises(errors.WrongArgTypeError, - object_graph._verify_subclasses, - 42, bindings.BindingSpec, 'an-arg-name') - - def test_raises_exception_if_element_is_not_subclass(self): - class NotBindingSpec(object): - pass - self.assertRaises( - errors.WrongArgElementTypeError, object_graph._verify_subclasses, - [NotBindingSpec()], bindings.BindingSpec, 'an-arg-name') - - -class VerifyCallableTest(unittest.TestCase): - - def test_verifies_callable_ok(self): - object_graph._verify_callable(lambda: None, 'unused') - - def test_raises_exception_if_not_callable(self): - self.assertRaises(errors.WrongArgTypeError, - object_graph._verify_callable, 42, 'an-arg-name') - - class PareToPresentArgsTest(unittest.TestCase): def test_removes_only_args_not_present(self): diff --git a/tests/object_providers_test.py b/tests/object_providers_test.py index 41aeeed..b3fe4ac 100644 --- a/tests/object_providers_test.py +++ b/tests/object_providers_test.py @@ -14,22 +14,18 @@ """ -import inspect import unittest from pinject import arg_binding_keys -from pinject import binding_keys from pinject import bindings from pinject import decorators from pinject import errors from pinject import injection_contexts from pinject import object_providers from pinject import scoping -from nose.tools import nottest -@nottest -def new_test_obj_provider(arg_binding_key, instance, allow_injecting_none=True): +def new_obj_provider(arg_binding_key, instance, allow_injecting_none=True): binding_key = arg_binding_key.binding_key binding = bindings.new_binding_to_instance( binding_key, instance, 'a-scope', lambda: 'unused-desc') @@ -40,7 +36,6 @@ def new_test_obj_provider(arg_binding_key, instance, allow_injecting_none=True): binding_mapping, bindable_scopes, allow_injecting_none) -@nottest def new_injection_context(): return injection_contexts.InjectionContextFactory(lambda _1, _2: True).new( lambda: None) @@ -53,7 +48,7 @@ class ObjectProviderTest(unittest.TestCase): def test_provides_from_arg_binding_key_successfully(self): arg_binding_key = arg_binding_keys.new('an-arg-name') - obj_provider = new_test_obj_provider(arg_binding_key, 'an-instance') + obj_provider = new_obj_provider(arg_binding_key, 'an-instance') self.assertEqual('an-instance', obj_provider.provide_from_arg_binding_key( _UNUSED_INJECTION_SITE_FN, @@ -61,7 +56,7 @@ def test_provides_from_arg_binding_key_successfully(self): def test_provides_provider_fn_from_arg_binding_key_successfully(self): arg_binding_key = arg_binding_keys.new('provide_foo') - obj_provider = new_test_obj_provider(arg_binding_key, 'an-instance') + obj_provider = new_obj_provider(arg_binding_key, 'an-instance') provide_fn = obj_provider.provide_from_arg_binding_key( _UNUSED_INJECTION_SITE_FN, arg_binding_key, new_injection_context()) @@ -69,14 +64,14 @@ def test_provides_provider_fn_from_arg_binding_key_successfully(self): def test_can_provide_none_from_arg_binding_key_when_allowed(self): arg_binding_key = arg_binding_keys.new('an-arg-name') - obj_provider = new_test_obj_provider(arg_binding_key, None) + obj_provider = new_obj_provider(arg_binding_key, None) self.assertIsNone(obj_provider.provide_from_arg_binding_key( _UNUSED_INJECTION_SITE_FN, arg_binding_key, new_injection_context())) def test_cannot_provide_none_from_binding_key_when_disallowed(self): arg_binding_key = arg_binding_keys.new('an-arg-name') - obj_provider = new_test_obj_provider(arg_binding_key, None, + obj_provider = new_obj_provider(arg_binding_key, None, allow_injecting_none=False) self.assertRaises(errors.InjectingNoneDisallowedError, obj_provider.provide_from_arg_binding_key, @@ -88,7 +83,7 @@ class Foo(object): def __init__(self, bar): self.bar = bar arg_binding_key = arg_binding_keys.new('bar') - obj_provider = new_test_obj_provider(arg_binding_key, 'a-bar') + obj_provider = new_obj_provider(arg_binding_key, 'a-bar') foo = obj_provider.provide_class(Foo, new_injection_context(), [], {}) self.assertEqual('a-bar', foo.bar) @@ -97,7 +92,7 @@ class SomeClass(object): @decorators.inject(['baz']) def __init__(self, foo, bar, baz): self.foobarbaz = foo + bar + baz - obj_provider = new_test_obj_provider(arg_binding_keys.new('baz'), 'baz') + obj_provider = new_obj_provider(arg_binding_keys.new('baz'), 'baz') some_class = obj_provider.provide_class( SomeClass, new_injection_context(), ['foo'], {'bar': 'bar'}) self.assertEqual('foobarbaz', some_class.foobarbaz) @@ -106,7 +101,7 @@ def test_provides_class_with_init_as_method_wrapper_successfully(self): class Foo(object): pass arg_binding_key = arg_binding_keys.new('unused') - obj_provider = new_test_obj_provider(arg_binding_key, 'unused') + obj_provider = new_obj_provider(arg_binding_key, 'unused') self.assertIsInstance( obj_provider.provide_class(Foo, new_injection_context(), [], {}), Foo) @@ -115,7 +110,7 @@ def test_calls_with_injection_successfully(self): def foo(bar): return 'a-foo-and-' + bar arg_binding_key = arg_binding_keys.new('bar') - obj_provider = new_test_obj_provider(arg_binding_key, 'a-bar') + obj_provider = new_obj_provider(arg_binding_key, 'a-bar') self.assertEqual('a-foo-and-a-bar', obj_provider.call_with_injection( foo, new_injection_context(), [], {})) @@ -124,7 +119,7 @@ def test_gets_injection_kwargs_successfully(self): def foo(bar): pass arg_binding_key = arg_binding_keys.new('bar') - obj_provider = new_test_obj_provider(arg_binding_key, 'a-bar') + obj_provider = new_obj_provider(arg_binding_key, 'a-bar') pargs, kwargs = obj_provider.get_injection_pargs_kwargs( foo, new_injection_context(), [], {}) self.assertEqual([], pargs) diff --git a/tests/scoping_test.py b/tests/scoping_test.py index 728abfa..1cbc417 100644 --- a/tests/scoping_test.py +++ b/tests/scoping_test.py @@ -33,8 +33,8 @@ def provider_fn(): scope = scoping.PrototypeScope() binding_key = binding_keys.new('unused') self.assertEqual( - range(10), - [scope.provide(binding_key, provider_fn) for _ in xrange(10)]) + list(range(10)), + [scope.provide(binding_key, provider_fn) for _ in range(10)]) class SingletonScopeTest(unittest.TestCase): @@ -74,7 +74,7 @@ def test_adds_default_scopes_to_given_scopes(self): def test_returns_default_scopes_if_none_given(self): id_to_scope = scoping.get_id_to_scope_with_defaults() - self.assertEqual(set([scoping.SINGLETON, scoping.PROTOTYPE]), + self.assertEqual({scoping.SINGLETON, scoping.PROTOTYPE}, set(id_to_scope.keys())) def test_does_not_allow_overriding_prototype_scope(self): diff --git a/tests/support_test.py b/tests/support_test.py new file mode 100644 index 0000000..b5fcd82 --- /dev/null +++ b/tests/support_test.py @@ -0,0 +1,193 @@ +"""Copyright 2013 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +import unittest +import types +import inspect + +from pinject import support +from pinject import bindings +from pinject import errors + + +class VerifyTypeTest(unittest.TestCase): + + def test_verifies_correct_type_ok(self): + support._verify_type(inspect.ismodule, types, 'unused', 'module') + + def test_raises_exception_if_incorrect_type(self): + self.assertRaises(errors.WrongArgTypeError, support._verify_type, + inspect.ismodule, 'not-a-module', + 'an-arg-name', 'module') + + +class VerifyTypesTest(unittest.TestCase): + + def test_verifies_empty_sequence_ok(self): + def fn_checker(elt): + return True + support._verify_types(fn_checker, [], 'unused', 'no_type') + + def test_verifies_correct_type_ok(self): + support._verify_types(inspect.ismodule, [types], 'unused', 'module') + + def test_raises_exception_if_not_sequence(self): + self.assertRaises(errors.WrongArgTypeError, support._verify_types, + inspect.ismodule, 42, 'an-arg-name', 'module') + + def test_raises_exception_if_element_is_incorrect_type(self): + self.assertRaises(errors.WrongArgElementTypeError, + support._verify_types, inspect.ismodule, + ['not-a-module'], 'an-arg-name', 'module') + + +class VerifySubclassesTest(unittest.TestCase): + + def test_verifies_empty_sequence_ok(self): + support.verify_subclasses([], bindings.BindingSpec, 'unused') + + def test_verifies_correct_type_ok(self): + class SomeBindingSpec(bindings.BindingSpec): + pass + support.verify_subclasses( + [SomeBindingSpec()], bindings.BindingSpec, 'unused') + + def test_raises_exception_if_not_sequence(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_subclasses, + 42, bindings.BindingSpec, 'an-arg-name') + + def test_raises_exception_if_element_is_not_subclass(self): + class NotBindingSpec(object): + pass + self.assertRaises( + errors.WrongArgElementTypeError, support.verify_subclasses, + [NotBindingSpec()], bindings.BindingSpec, 'an-arg-name') + + +class VerifyCallableTest(unittest.TestCase): + + def test_verifies_callable_ok(self): + support.verify_callable(lambda: None, 'unused') + + def test_raises_exception_if_not_callable(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_callable, 42, 'an-arg-name') + + +class VerifyModuleTypesTest(unittest.TestCase): + + def test_verifies_module_types_ok(self): + support.verify_module_types([types], 'unused') + + def test_raises_exception_if_not_module_types(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_module_types, 42, 'an-arg-name') + + +class VerifyClassTypesTest(unittest.TestCase): + + def test_verifies_module_types_ok(self): + class Foo(object): + pass + support.verify_class_types([Foo], 'unused') + + def test_raises_exception_if_not_class_types(self): + self.assertRaises(errors.WrongArgTypeError, + support.verify_class_types, 42, 'an-arg-name') + + +class IsSequenceTest(unittest.TestCase): + + def test_argument_identified_as_sequence_instance(self): + self.assertTrue(support.is_sequence(list())) + + def test_argument_identified_as_not_sequence_instance(self): + self.assertFalse(support.is_sequence(None)) + + +class IsStringTest(unittest.TestCase): + + def test_argument_identified_as_string_instance(self): + self.assertTrue(support.is_string('some string')) + + def test_argument_identified_as_not_string_instance(self): + self.assertFalse(support.is_string(None)) + + +class IsConstructorDefinedTest(unittest.TestCase): + + def test_constructor_present_detection(self): + class Foo(object): + def __init__(self): + pass + self.assertTrue(support.is_constructor_defined(Foo)) + + def test_constructor_not_present_detection(self): + class Foo(object): + pass + self.assertFalse(support.is_constructor_defined(Foo)) + + +class GetMethodArgsTest(unittest.TestCase): + + def test_get_method_args_no_params(self): + def simple(): + pass + arg_names, varargs, keywords, defaults = support.get_method_args(simple) + self.assertEqual([], arg_names) + self.assertEqual(None, varargs) + self.assertEqual(None, keywords) + self.assertEqual(None, defaults) + + def test_get_method_args_single_params(self): + def simple(arg1): + pass + arg_names, varargs, keywords, defaults = support.get_method_args(simple) + self.assertEqual(['arg1'], arg_names) + self.assertEqual(None, varargs) + self.assertEqual(None, keywords) + self.assertEqual(None, defaults) + + def test_get_method_args_single_default_params(self): + def simple(arg1, arg2='foo'): + pass + arg_names, varargs, keywords, defaults = support.get_method_args(simple) + self.assertEqual(['arg1', 'arg2'], arg_names) + self.assertEqual(None, varargs) + self.assertEqual(None, keywords) + self.assertEqual(('foo',), defaults) + + def test_get_method_args_varargs(self): + def simple(arg1, arg2='foo', *args): + pass + arg_names, varargs, keywords, defaults = support.get_method_args(simple) + self.assertEqual(['arg1', 'arg2'], arg_names) + self.assertEqual('args', varargs) + self.assertEqual(None, keywords) + self.assertEqual(('foo',), defaults) + + def test_get_method_args_kvargs(self): + def simple(arg1, arg2='foo', *args, **kvargs): + pass + arg_names, varargs, keywords, defaults = support.get_method_args(simple) + self.assertEqual(['arg1', 'arg2'], arg_names) + self.assertEqual('args', varargs) + self.assertEqual('kvargs', keywords) + self.assertEqual(('foo',), defaults) + + def test_raises_exception_if_not_method(self): + self.assertRaises(TypeError, support.get_method_args, None) diff --git a/test_errors.py b/tests/test_errors.py similarity index 99% rename from test_errors.py rename to tests/test_errors.py index 1398041..2eee371 100755 --- a/test_errors.py +++ b/tests/test_errors.py @@ -16,6 +16,7 @@ """ +import six import inspect import sys import traceback @@ -354,6 +355,6 @@ def print_wrong_arg_type_error(): all_print_method_pairs.sort(key=lambda x: x[0]) all_print_methods = [value for name, value in all_print_method_pairs] for print_method in all_print_methods: - print '#' * 78 + six.print_('#' * 78) print_method() -print '#' * 78 +six.print_('#' * 78)