diff --git a/docs/loaders.rst b/docs/loaders.rst index 16ca823..5e9d19d 100644 --- a/docs/loaders.rst +++ b/docs/loaders.rst @@ -83,3 +83,21 @@ load all the objects from the modules, rather than the modules themselves. plugins = ObjectLoader().load('myplugins') +Handling Import Errors +---------------------- + +If import errors are encountered while loading plugins, they are +captured and plugin loading continues. These exceptions can be +retrieved from the object returned by ``load()``. + +Example: + +:: + + from straight.plugin import load + + plugins = load('myplugins') + + if plugins.has_exceptions(): + plugin_exceptions = plugins.exceptions() + print(plugin_exceptions) diff --git a/straight/plugin/loaders.py b/straight/plugin/loaders.py index 1a17732..4d57857 100644 --- a/straight/plugin/loaders.py +++ b/straight/plugin/loaders.py @@ -22,6 +22,7 @@ class Loader(object): def __init__(self, *args, **kwargs): self._cache = [] self.loaded = False + self._exceptions = [] def _fill_cache(self, *args, **kwargs): raise NotImplementedError() @@ -32,7 +33,7 @@ def load(self, *args, **kwargs): self._post_fill() self._order() self.loaded = True - return PluginManager(self._cache) + return PluginManager(self._cache, self._exceptions) def _meta(self, plugin): meta = getattr(plugin, "__plugin__", None) @@ -114,9 +115,8 @@ def _findPluginModules(self, namespace): try: module = import_module(import_path) - except ImportError: - # raise Exception(import_path) - + except ImportError as import_error: + self._exceptions.append(import_error) module = None if module is not None: @@ -151,6 +151,7 @@ def _fill_cache(self, namespace): objects.append(getattr(module, attr_name)) self._cache = objects + self._exceptions = modules.exceptions() return objects diff --git a/straight/plugin/manager.py b/straight/plugin/manager.py index 811d04b..82f9abd 100644 --- a/straight/plugin/manager.py +++ b/straight/plugin/manager.py @@ -1,6 +1,7 @@ class PluginManager(object): - def __init__(self, plugins): + def __init__(self, plugins, exceptions): self._plugins = plugins + self._exceptions = exceptions def __iter__(self): return iter(self._plugins) @@ -20,7 +21,7 @@ def produce(self, *args, **kwargs): for p in self._plugins: r = p(*args, **kwargs) new_plugins.append(r) - return PluginManager(new_plugins) + return PluginManager(new_plugins, exceptions = []) def call(self, methodname, *args, **kwargs): """Call a common method on all the plugins, if it exists.""" @@ -58,3 +59,9 @@ def pipe(self, methodname, first_arg, *args, **kwargs): if r is not None: first_arg = r return r + + def exceptions(self): + return self._exceptions + + def has_exceptions(self): + return len(self._exceptions) > 0 diff --git a/test-packages/import-error-test-plugins/testplugin/__init__.py b/test-packages/import-error-test-plugins/testplugin/__init__.py new file mode 100644 index 0000000..b36383a --- /dev/null +++ b/test-packages/import-error-test-plugins/testplugin/__init__.py @@ -0,0 +1,3 @@ +from pkgutil import extend_path + +__path__ = extend_path(__path__, __name__) diff --git a/test-packages/import-error-test-plugins/testplugin/error.py b/test-packages/import-error-test-plugins/testplugin/error.py new file mode 100644 index 0000000..7bd44c6 --- /dev/null +++ b/test-packages/import-error-test-plugins/testplugin/error.py @@ -0,0 +1,4 @@ +import nonexistent.package + +def do(x): + return x + 1 diff --git a/tests.py b/tests.py index 841a33a..4170e93 100755 --- a/tests.py +++ b/tests.py @@ -124,6 +124,7 @@ class ObjectLoaderTestCase(LoaderTestCaseMixin, unittest.TestCase): paths = ( os.path.join(os.path.dirname(__file__), "test-packages", "more-test-plugins"), os.path.join(os.path.dirname(__file__), "test-packages", "some-test-plugins"), + os.path.join(os.path.dirname(__file__), "test-packages", "import-error-test-plugins"), ) def setUp(self): @@ -131,8 +132,10 @@ def setUp(self): super(ObjectLoaderTestCase, self).setUp() def test_load_all(self): - objects = list(self.loader.load("testplugin")) + plugins = self.loader.load("testplugin") + objects = list(plugins) self.assertEqual(len(objects), 2, str(objects)[:100] + " ...") + self.assertEqual(len(plugins.exceptions()), 1) class ClassLoaderTestCase(LoaderTestCaseMixin, unittest.TestCase): @@ -243,7 +246,7 @@ def test_plugin(self): class PluginManagerTestCase(unittest.TestCase): def setUp(self): - self.m = manager.PluginManager([mock.Mock(), mock.Mock()]) + self.m = manager.PluginManager([mock.Mock(), mock.Mock()], []) def test_first(self): self.m._plugins[0].x.return_value = 1 @@ -261,7 +264,7 @@ def plus_one(x): self.assertEqual(3, self.m.pipe("x", 1)) def test_pipe_no_plugins_found(self): - no_plugins = manager.PluginManager([]) + no_plugins = manager.PluginManager([], []) self.assertEqual(1, no_plugins.pipe("x", 1)) def test_call(self): @@ -276,6 +279,10 @@ def test_produce(self): assert products[1] is self.m._plugins[1].return_value self.m._plugins[1].called_with(1, 2) + def test_exceptions(self): + encountered_exceptions = manager.PluginManager([], [mock.Mock()]) + self.assertTrue(encountered_exceptions.has_exceptions()) + self.assertEqual(len(encountered_exceptions.exceptions()), 1) if __name__ == "__main__": unittest.main()