From 3714bd5e11b0a2840d6453a6ae4a2de6713c374f Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 10:13:18 +0000 Subject: [PATCH 01/30] Fixes for array data source tests. --- chaco/tests/arraydatasource_test_case.py | 156 +++++++++++++++++++---- 1 file changed, 130 insertions(+), 26 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index 1ca889637..908115f63 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -4,59 +4,163 @@ import unittest -from numpy import arange, array, allclose, empty, isnan, nan +from numpy import arange, array, allclose, empty, isnan, nan, ones +from numpy.testing import assert_almost_equal, assert_array_equal import numpy as np from chaco.api import ArrayDataSource, PointDataSource +from traits.testing.unittest_tools import UnittestTools -class ArrayDataTestCase(unittest.TestCase): - def test_basic_set_get(self): +class ArrayDataTestCase(UnittestTools, unittest.TestCase): + + def test_init_defaults(self): + data_source = ArrayDataSource() + assert_array_equal(data_source._data, []) + self.assertEqual(data_source.value_dimension, "scalar") + self.assertEqual(data_source.sort_order, "none") + self.assertFalse(data_source.is_masked()) + + def test_basic_setup(self): myarray = arange(10) - sd = ArrayDataSource(myarray) - self.assertTrue(allclose(myarray, sd._data)) - self.assert_(sd.value_dimension == "scalar") - return + data_source = ArrayDataSource(myarray) + assert_array_equal(myarray, data_source._data) + self.assertEqual(data_source.value_dimension, "scalar") + self.assertEqual(data_source.sort_order, "none") + self.assertFalse(data_source.is_masked()) + + def test_set_data(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + new_array = arange(0, 20, 2) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.set_data(new_array) + + assert_array_equal(new_array, data_source._data) + self.assertEqual(data_source.get_bounds(), (0, 18)) + self.assertEqual(data_source.sort_order, "none") + + def test_set_data_ordered(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + new_array = arange(20, 0, -2) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.set_data(new_array, sort_order='descending') + + assert_array_equal(new_array, data_source._data) + self.assertEqual(data_source.get_bounds(), (2, 20)) + self.assertEqual(data_source.sort_order, "descending") + + def test_set_mask(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + mymask = array([i % 2 for i in myarray], dtype=bool) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.set_mask(mymask) + + assert_array_equal(myarray, data_source._data) + assert_array_equal(mymask, data_source._cached_mask) + self.assertTrue(data_source.is_masked()) + self.assertEqual(data_source.get_bounds(), (0, 9)) + + def test_remove_mask(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + mymask = array([i % 2 for i in myarray], dtype=bool) + data_source.set_mask(mymask) + self.assertTrue(data_source.is_masked()) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.remove_mask() + + assert_array_equal(myarray, data_source._data) + self.assertIsNone(data_source._cached_mask, None) + self.assertFalse(data_source.is_masked()) + self.assertEqual(data_source.get_bounds(), (0, 9)) + + def test_get_data(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + + assert_array_equal(myarray, data_source.get_data()) + + def test_get_data_no_data(self): + data_source = ArrayDataSource(None) + + assert_array_equal(data_source.get_data(), 0.0) + + def test_get_data_mask(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + mymask = array([i % 2 for i in myarray], dtype=bool) + data_source.set_mask(mymask) + + data, mask = data_source.get_data_mask() + assert_array_equal(data, myarray) + + @unittest.skip('get_data_mask() fails in this case') + def test_get_data_mask_no_data(self): + data_source = ArrayDataSource(None) + + data, mask = data_source.get_data_mask() + # XXX this is what I would expect, given get_data() behaviour + assert_array_equal(data, []) + assert_array_equal(data, []) + + def test_get_data_mask_no_mask(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + + data, mask = data_source.get_data_mask() + assert_array_equal(data, myarray) + assert_array_equal(mask, ones(shape=10, dtype=bool)) def test_bounds(self): # ascending myarray = arange(10) - sd = ArrayDataSource(myarray, sort_order="ascending") - bounds = sd.get_bounds() - self.assert_(bounds == (0,9)) + data_source = ArrayDataSource(myarray, sort_order="ascending") + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 9)) # descending myarray = arange(10)[::-1] - sd = ArrayDataSource(myarray, sort_order="descending") - bounds = sd.get_bounds() - self.assert_(bounds == (0,9)) + data_source = ArrayDataSource(myarray, sort_order="descending") + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 9)) # no order - myarray = array([12,3,0,9,2,18,3]) - sd = ArrayDataSource(myarray, sort_order="none") - bounds = sd.get_bounds() - self.assert_(bounds == (0,18)) - return + myarray = array([12, 3, 0, 9, 2, 18, 3]) + data_source = ArrayDataSource(myarray, sort_order="none") + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 18)) + + def test_bounds_empty(self): + data_source = ArrayDataSource() + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 0)) def test_data_size(self): - # We know that ScalarData always returns the exact length of its data + # We know that ArrayDataTestCase always returns the exact length of + # its data myarray = arange(913) - sd = ArrayDataSource(myarray) - self.assert_(len(myarray) == sd.get_size()) - return + data_source = ArrayDataSource(myarray) + self.assertEqual(len(myarray), data_source.get_size()) def test_bounds_all_nans(self): myarray = empty(10) myarray[:] = nan - sd = ArrayDataSource(myarray) - bounds = sd.get_bounds() + data_source = ArrayDataSource(myarray) + bounds = data_source.get_bounds() self.assertTrue(isnan(bounds[0])) self.assertTrue(isnan(bounds[1])) def test_bounds_non_numeric(self): myarray = np.array([u'abc', u'foo', u'bar', u'def'], dtype=unicode) - sd = ArrayDataSource(myarray) - bounds = sd.get_bounds() + data_source = ArrayDataSource(myarray) + bounds = data_source.get_bounds() self.assertEqual(bounds, (u'abc', u'def')) From d9ba227fd51d4dedd017caf22c3795d35c66b098 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 10:38:46 +0000 Subject: [PATCH 02/30] Use unitest2. --- chaco/tests/arraydatasource_test_case.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index 908115f63..b51cbfff3 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -2,10 +2,10 @@ Test of basic dataseries behavior. """ -import unittest +import unittest2 as unittest from numpy import arange, array, allclose, empty, isnan, nan, ones -from numpy.testing import assert_almost_equal, assert_array_equal +from numpy.testing import assert_array_equal import numpy as np from chaco.api import ArrayDataSource, PointDataSource From 8bfa7023f7f2079b389c53a22e09233379ec8ee7 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 13:58:26 +0000 Subject: [PATCH 03/30] More robust checking of get_bounds() --- chaco/tests/arraydatasource_test_case.py | 52 ++++++++++++++++++++---- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index b51cbfff3..3dfc0baa9 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -137,32 +137,66 @@ def test_bounds(self): bounds = data_source.get_bounds() self.assertEqual(bounds, (0, 18)) + def test_bounds_length_one(self): + # this is special-cased in the code, so exercise the code path + data_source = ArrayDataSource(array([1.0])) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (1.0, 1.0)) + + def test_bounds_length_zero(self): + # this is special-cased in the code, so exercise the code path + data_source = ArrayDataSource(array([])) + bounds = data_source.get_bounds() + # XXX this is sort of inconsistent with test_bounds_all_nan() + self.assertEqual(bounds, (0, 0)) + def test_bounds_empty(self): data_source = ArrayDataSource() bounds = data_source.get_bounds() + # XXX this is sort of inconsistent with test_bounds_all_nan() self.assertEqual(bounds, (0, 0)) - def test_data_size(self): - # We know that ArrayDataTestCase always returns the exact length of - # its data - myarray = arange(913) - data_source = ArrayDataSource(myarray) - self.assertEqual(len(myarray), data_source.get_size()) - def test_bounds_all_nans(self): myarray = empty(10) myarray[:] = nan - data_source = ArrayDataSource(myarray) - bounds = data_source.get_bounds() + sd = ArrayDataSource(myarray) + bounds = sd.get_bounds() self.assertTrue(isnan(bounds[0])) self.assertTrue(isnan(bounds[1])) + def test_bounds_some_nan(self): + data_source = ArrayDataSource(array([np.nan, 3, 0, 9, np.nan, 18, 3])) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 18)) + + def test_bounds_negative_inf(self): + data_source = ArrayDataSource(array([12, 3, -np.inf, 9, 2, 18, 3])) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (-np.inf, 18)) + + def test_bounds_positive_inf(self): + data_source = ArrayDataSource(array([12, 3, 0, 9, 2, np.inf, 3])) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, np.inf)) + + def test_bounds_negative_positive_inf(self): + data_source = ArrayDataSource(array([12, 3, -np.inf, 9, 2, np.inf, 3])) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (-np.inf, np.inf)) + def test_bounds_non_numeric(self): myarray = np.array([u'abc', u'foo', u'bar', u'def'], dtype=unicode) data_source = ArrayDataSource(myarray) bounds = data_source.get_bounds() self.assertEqual(bounds, (u'abc', u'def')) + def test_data_size(self): + # We know that ArrayDataTestCase always returns the exact length of + # its data + myarray = arange(913) + data_source = ArrayDataSource(myarray) + self.assertEqual(len(myarray), data_source.get_size()) + class PointDataTestCase(unittest.TestCase): # Since PointData is mostly the same as ScalarData, the key things to From 6edd34e0f5e6178bcdb12924acb81f9d5fe034ab Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 14:03:13 +0000 Subject: [PATCH 04/30] Fix bare except statement. --- chaco/array_data_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chaco/array_data_source.py b/chaco/array_data_source.py index 372f437cf..f61a99821 100644 --- a/chaco/array_data_source.py +++ b/chaco/array_data_source.py @@ -251,7 +251,7 @@ def _compute_bounds(self, data=None): data_len = 0 try: data_len = len(data) - except: + except Exception: pass if data_len == 0: self._min_index = 0 From 262c2bc45acd39c067bf6547258afd0ec954cb10 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 14:49:35 +0000 Subject: [PATCH 05/30] Add image data tests, minor fixes for array data tests. --- chaco/tests/arraydatasource_test_case.py | 9 +- chaco/tests/image_data_test_case.py | 128 +++++++++++++++++++++++ 2 files changed, 133 insertions(+), 4 deletions(-) create mode 100644 chaco/tests/image_data_test_case.py diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index 3dfc0baa9..bb95430ed 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -90,6 +90,7 @@ def test_get_data(self): def test_get_data_no_data(self): data_source = ArrayDataSource(None) + # XXX A _scalar_? Not array([]) or None? assert_array_equal(data_source.get_data(), 0.0) def test_get_data_mask(self): @@ -107,8 +108,8 @@ def test_get_data_mask_no_data(self): data, mask = data_source.get_data_mask() # XXX this is what I would expect, given get_data() behaviour - assert_array_equal(data, []) - assert_array_equal(data, []) + assert_array_equal(data, 0.0) + assert_array_equal(data, True) def test_get_data_mask_no_mask(self): myarray = arange(10) @@ -147,13 +148,13 @@ def test_bounds_length_zero(self): # this is special-cased in the code, so exercise the code path data_source = ArrayDataSource(array([])) bounds = data_source.get_bounds() - # XXX this is sort of inconsistent with test_bounds_all_nan() + # XXX this is sort of inconsistent with test_bounds_all_nans() self.assertEqual(bounds, (0, 0)) def test_bounds_empty(self): data_source = ArrayDataSource() bounds = data_source.get_bounds() - # XXX this is sort of inconsistent with test_bounds_all_nan() + # XXX this is sort of inconsistent with test_bounds_all_nans() self.assertEqual(bounds, (0, 0)) def test_bounds_all_nans(self): diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py new file mode 100644 index 000000000..dced00eec --- /dev/null +++ b/chaco/tests/image_data_test_case.py @@ -0,0 +1,128 @@ +""" +Test of basic dataseries behavior. +""" + +import unittest + +from numpy import arange, swapaxes +from numpy.testing import assert_array_equal + +from chaco.api import ImageData +from traits.testing.unittest_tools import UnittestTools + + +class ArrayDataTestCase(UnittestTools, unittest.TestCase): + + def test_init_defaults(self): + data_source = ImageData() + assert_array_equal(data_source.data, []) + # this isn't right - + #self.assertEqual(data_source.value_dimension, "scalar") + #self.assertEqual(data_source.image_dimension, "image") + + def test_basic_setup(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + assert_array_equal(myarray, data_source.data) + #self.assertEqual(data_source.value_dimension, "scalar") + self.assertFalse(data_source.is_masked()) + + def test_set_data(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + new_array = arange(0, 30, 2).reshape(5, 3, 1) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.set_data(new_array) + + assert_array_equal(new_array, data_source.data) + self.assertEqual(data_source.get_bounds(), (0, 28)) + + def test_get_data(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + assert_array_equal(myarray, data_source.get_data()) + + def test_get_data_no_data(self): + data_source = ImageData() + + self.assertIsNone(data_source.get_data()) + + def test_get_data_transposed(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray, transposed=True) + + assert_array_equal(swapaxes(myarray, 0, 1), data_source.get_data()) + + def test_get_data_mask(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + # XXX this is probably not the right thing + with self.assertRaises(NotImplementedError): + data, mask = data_source.get_data_mask() + + def test_get_data_mask_no_data(self): + data_source = ImageData() + + # XXX this is probably not the right thing + with self.assertRaises(NotImplementedError): + data, mask = data_source.get_data_mask() + + def test_bounds(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 14)) + + @unittest.skip('test_bounds_empty() fails in this case') + def test_bounds_empty(self): + data_source = ImageData() + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 0)) + + def test_data_size(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + self.assertEqual(15, data_source.get_size()) + + def test_data_size_no_data(self): + data_source = ImageData() + self.assertEqual(0, data_source.get_size()) + + def test_get_width(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + self.assertEqual(3, data_source.get_width()) + + def test_get_width_transposed(self): + myarray = arange(15).reshape(5, 3) + data_source = ImageData(data=myarray, transposed=True) + + self.assertEqual(5, data_source.get_width()) + + def test_get_height(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + self.assertEqual(5, data_source.get_height()) + + def test_get_height_transposed(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray, transposed=True) + + self.assertEqual(3, data_source.get_height()) + + def test_array_bounds(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + self.assertEqual(((0, 3), (0, 5)), data_source.get_array_bounds()) + + def test_array_bounds_transposed(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray, transposed=True) + + self.assertEqual(((0, 5), (0, 3)), data_source.get_array_bounds()) From ffa0e9ed62a461e9f689099404b8aada29c6258c Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 14:59:39 +0000 Subject: [PATCH 06/30] Replace unittest with unittest2 --- chaco/tests/image_data_test_case.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py index dced00eec..c93420908 100644 --- a/chaco/tests/image_data_test_case.py +++ b/chaco/tests/image_data_test_case.py @@ -2,7 +2,7 @@ Test of basic dataseries behavior. """ -import unittest +import unittest2 as unittest from numpy import arange, swapaxes from numpy.testing import assert_array_equal From 5e16075dc746eb59e18146464a532c23dabbc379 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 16:50:34 +0000 Subject: [PATCH 07/30] Add image reading tests for ImageData. --- MANIFEST.in | 2 ++ chaco/tests/data/PngSuite/LICENSE.txt | 8 ++++++++ chaco/tests/data/PngSuite/basi6a08.png | Bin 0 -> 361 bytes chaco/tests/data/PngSuite/basn2c08.png | Bin 0 -> 145 bytes chaco/tests/image_data_test_case.py | 21 ++++++++++++++++++++- image_LICENSE.txt | 4 ++++ setup.py | 7 +++++-- 7 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 chaco/tests/data/PngSuite/LICENSE.txt create mode 100644 chaco/tests/data/PngSuite/basi6a08.png create mode 100644 chaco/tests/data/PngSuite/basn2c08.png diff --git a/MANIFEST.in b/MANIFEST.in index 312ff4687..6242b2f9e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ include chaco/*.h +include chaco/tests/data/PngSuite/*.png +include chaco/tests/data/PngSuite/LICENSE.txt diff --git a/chaco/tests/data/PngSuite/LICENSE.txt b/chaco/tests/data/PngSuite/LICENSE.txt new file mode 100644 index 000000000..6f96cebe0 --- /dev/null +++ b/chaco/tests/data/PngSuite/LICENSE.txt @@ -0,0 +1,8 @@ +PngSuite +-------- + +Permission to use, copy, modify and distribute these images for any +purpose and without fee is hereby granted. + + +(c) Willem van Schaik, 1996, 2011 diff --git a/chaco/tests/data/PngSuite/basi6a08.png b/chaco/tests/data/PngSuite/basi6a08.png new file mode 100644 index 0000000000000000000000000000000000000000..aecb32e0d9e347ccdcab5d7fdad2dde7aef9da8a GIT binary patch literal 361 zcmV-v0ha!WP)!N7@Xl>n0JWRT0P1ZZ0~F7IAF8qd>UAgssMo0sP&xyjf=W~LX-;!O00000NkvXX Hu0mjfr8|%> literal 0 HcmV?d00001 diff --git a/chaco/tests/data/PngSuite/basn2c08.png b/chaco/tests/data/PngSuite/basn2c08.png new file mode 100644 index 0000000000000000000000000000000000000000..db5ad15865f56e48e4bae5b43661d2dbc4e847e3 GIT binary patch literal 145 zcmeAS@N?(olHy`uVBq!ia0vp^3LwnE1SJ1Ryj={WSkfJR9T^zg78t&m77yfmc)B=- zRLpsM^&lsM0S}Wy>zj#xw-*UpyJ&w|_~YC|?Jc}4)(u=DKV%lXeNyXF;n2v$@6|4U rsD)ibtrx; literal 0 HcmV?d00001 diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py index c93420908..02b006fff 100644 --- a/chaco/tests/image_data_test_case.py +++ b/chaco/tests/image_data_test_case.py @@ -2,15 +2,20 @@ Test of basic dataseries behavior. """ -import unittest2 as unittest +import os +import unittest2 as unittest from numpy import arange, swapaxes from numpy.testing import assert_array_equal +from pkg_resources import resource_filename from chaco.api import ImageData from traits.testing.unittest_tools import UnittestTools +data_dir = resource_filename('chaco.tests', 'data') + + class ArrayDataTestCase(UnittestTools, unittest.TestCase): def test_init_defaults(self): @@ -126,3 +131,17 @@ def test_array_bounds_transposed(self): data_source = ImageData(data=myarray, transposed=True) self.assertEqual(((0, 5), (0, 3)), data_source.get_array_bounds()) + + def test_fromfile_png_rgb(self): + # basic smoke test - assume that kiva.image does the right thing + path = os.path.join(data_dir, 'PngSuite', 'basn2c08.png') + data_source = ImageData.fromfile(path) + + self.assertEqual(data_source.value_depth, 3) + + def test_fromfile_png_rgba(self): + # basic smoke test - assume that kiva.image does the right thing + path = os.path.join(data_dir, 'PngSuite', 'basi6a08.png') + data_source = ImageData.fromfile(path) + + self.assertEqual(data_source.value_depth, 4) diff --git a/image_LICENSE.txt b/image_LICENSE.txt index 9ebfe2a01..06416811d 100644 --- a/image_LICENSE.txt +++ b/image_LICENSE.txt @@ -17,3 +17,7 @@ examples/basic: chaco/layers/data: Dialog-error.svg | Tango, CC 2.5, modified to remove gadients Dialog-warning.svg | Tango, CC 2.5, modified to remove gadients + +chaco/tests/data/PngSuite: + basi6a08.png | PngSuite, free use, see PngSuite/LICENSE.txt + basi2c08.png | PngSuite, free use, see PngSuite/LICENSE.txt diff --git a/setup.py b/setup.py index a7362152f..2af55eda3 100644 --- a/setup.py +++ b/setup.py @@ -58,8 +58,11 @@ Topic :: Software Development Topic :: Software Development :: Libraries """.splitlines() if len(c.strip()) > 0], - package_data={'chaco': ['tools/toolbars/images/*.png', - 'layers/data/*.svg']}, + package_data={ + 'chaco': ['tools/toolbars/images/*.png', + 'layers/data/*.svg', + 'tests/data/PngSuite/*.png'] + }, description = 'interactive 2-dimensional plotting', long_description = open('README.rst').read(), download_url = ('http://www.enthought.com/repo/ets/chaco-%s.tar.gz' % From bee330ba5a84a1a53e27e258b28912d2f267e895 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 17:11:59 +0000 Subject: [PATCH 08/30] Ensure libpng is installed on Travis. --- .travis_before_install | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis_before_install b/.travis_before_install index 8e9059b30..ba6d47dd6 100644 --- a/.travis_before_install +++ b/.travis_before_install @@ -3,5 +3,5 @@ export DISPLAY=:99.0 sh -e /etc/init.d/xvfb start -sudo apt-get install python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools +sudo apt-get install libpng python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools sudo apt-get install python-wxgtk2.8 python-wxtools wx2.8-i18n From e7c0551094e1b5e70398c7f296498badd8679a82 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 17:23:41 +0000 Subject: [PATCH 09/30] ...and zlib for travis. --- .travis_before_install | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis_before_install b/.travis_before_install index ba6d47dd6..492f054a8 100644 --- a/.travis_before_install +++ b/.travis_before_install @@ -3,5 +3,5 @@ export DISPLAY=:99.0 sh -e /etc/init.d/xvfb start -sudo apt-get install libpng python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools +sudo apt-get install libpng zlib python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools sudo apt-get install python-wxgtk2.8 python-wxtools wx2.8-i18n From 1f182e225cf49303aee597bc477033b339980ef2 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 17:26:59 +0000 Subject: [PATCH 10/30] try again with Travis --- .travis_before_install | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis_before_install b/.travis_before_install index 492f054a8..7989eb107 100644 --- a/.travis_before_install +++ b/.travis_before_install @@ -3,5 +3,6 @@ export DISPLAY=:99.0 sh -e /etc/init.d/xvfb start -sudo apt-get install libpng zlib python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools +sudo apt-get install libpng zlib +sudo apt-get install python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools sudo apt-get install python-wxgtk2.8 python-wxtools wx2.8-i18n From ddd8c407a107743a4b98ecbfe7c0cb755065c9cb Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 18:42:12 +0000 Subject: [PATCH 11/30] Install PIL via apt get --- .travis_before_install | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis_before_install b/.travis_before_install index 7989eb107..2849a703d 100644 --- a/.travis_before_install +++ b/.travis_before_install @@ -3,6 +3,6 @@ export DISPLAY=:99.0 sh -e /etc/init.d/xvfb start -sudo apt-get install libpng zlib +sudo apt-get install python-pil sudo apt-get install python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools sudo apt-get install python-wxgtk2.8 python-wxtools wx2.8-i18n From 6eeb229972670f671018d45d43f3ad9f132c8056 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 19:03:13 +0000 Subject: [PATCH 12/30] Yet one more attempt to get PIL to includ PNG support. --- .travis.yml | 5 +++++ .travis_before_install | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 7526be145..b1e9439c0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,11 @@ python: before_install: - sudo apt-get update - sudo apt-get install python-numpy swig + # Simlinks for PIL compilation + - sudo ln -s /usr/lib/`uname -i`-linux-gnu/libfreetype.so /usr/lib/ + - sudo ln -s /usr/lib/`uname -i`-linux-gnu/libjpeg.so /usr/lib/ + - sudo ln -s /usr/lib/`uname -i`-linux-gnu/libpng.so /usr/lib/ + - sudo ln -s /usr/lib/`uname -i`-linux-gnu/libz.so /usr/lib/ - source .travis_before_install install: - pip install cython diff --git a/.travis_before_install b/.travis_before_install index 2849a703d..8e9059b30 100644 --- a/.travis_before_install +++ b/.travis_before_install @@ -3,6 +3,5 @@ export DISPLAY=:99.0 sh -e /etc/init.d/xvfb start -sudo apt-get install python-pil sudo apt-get install python-qt4 python-qt4-dev python-sip python-qt4-gl libqt4-scripttools sudo apt-get install python-wxgtk2.8 python-wxtools wx2.8-i18n From 2d9082329f27f67c48d5abf69ca6f52a4b7b1bfa Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 19:57:13 +0000 Subject: [PATCH 13/30] Add tests for function data source. --- chaco/tests/function_data_source_test_case.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 chaco/tests/function_data_source_test_case.py diff --git a/chaco/tests/function_data_source_test_case.py b/chaco/tests/function_data_source_test_case.py new file mode 100644 index 000000000..9faa2d4c9 --- /dev/null +++ b/chaco/tests/function_data_source_test_case.py @@ -0,0 +1,100 @@ +""" +Test of basic dataseries behavior. +""" + +import unittest2 as unittest + +from numpy import array, linspace, nan, ones +from numpy.testing import assert_array_equal +import numpy as np + +from chaco.api import DataRange1D +from chaco.function_data_source import FunctionDataSource +from traits.testing.unittest_tools import UnittestTools + + +class FunctionDataSourceTestCase(UnittestTools, unittest.TestCase): + + def test_init_defaults(self): + data_source = FunctionDataSource() + assert_array_equal(data_source._data, []) + self.assertEqual(data_source.value_dimension, "scalar") + self.assertEqual(data_source.sort_order, "ascending") + self.assertFalse(data_source.is_masked()) + + def test_basic_setup(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + assert_array_equal(myfunc, data_source.func) + self.assertEqual(data_source.value_dimension, "scalar") + self.assertEqual(data_source.sort_order, "ascending") + self.assertFalse(data_source.is_masked()) + + def test_set_data(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + + with self.assertRaises(RuntimeError): + data_source.set_data(lambda low, high: linspace(low, high, 101)) + + def test_set_mask(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + mymask = array([i % 2 for i in xrange(101)], dtype=bool) + + with self.assertRaises(NotImplementedError): + data_source.set_mask(mymask) + + def test_remove_mask(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + + with self.assertRaises(NotImplementedError): + data_source.remove_mask() + + def test_get_data(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + + assert_array_equal(linspace(0.0, 1.0, 101)**2, data_source.get_data()) + + def test_get_data_no_data(self): + data_source = FunctionDataSource() + + assert_array_equal(data_source.get_data(), array([], dtype=float)) + + def test_get_data_mask(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + + data, mask = data_source.get_data_mask() + assert_array_equal(data, linspace(0.0, 1.0, 101)**2) + assert_array_equal(mask, ones(shape=101, dtype=bool)) + + def test_bounds(self): + myfunc = lambda low, high: linspace(low, high, 100)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, high_setting=2.0) + + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0.0, 4.0)) + + @unittest.skip("default sort_order is ascending, which isn't right") + def test_bounds_non_monotone(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=-2.0, + high_setting=2.0) + + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0.0, 4.0)) + + def test_data_size(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=2.0) + + self.assertEqual(101, data_source.get_size()) From 530f2bbc6f8d2fee8253269d1a5916ff5b289cd0 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 21:09:48 +0000 Subject: [PATCH 14/30] Test range changes. --- chaco/tests/function_data_source_test_case.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/chaco/tests/function_data_source_test_case.py b/chaco/tests/function_data_source_test_case.py index 9faa2d4c9..548308833 100644 --- a/chaco/tests/function_data_source_test_case.py +++ b/chaco/tests/function_data_source_test_case.py @@ -37,6 +37,37 @@ def test_set_data(self): with self.assertRaises(RuntimeError): data_source.set_data(lambda low, high: linspace(low, high, 101)) + def test_range_high_changed(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.data_range.high_setting = 2.0 + + assert_array_equal(linspace(0.0, 2.0, 101)**2, data_source.get_data()) + + def test_range_low_changed(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.data_range.low_setting = -1.0 + + assert_array_equal(linspace(-1.0, 1.0, 101)**2, data_source.get_data()) + + def test_range_data_range_changed(self): + myfunc = lambda low, high: linspace(low, high, 101)**2 + data_source = FunctionDataSource(func=myfunc) + data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.data_range = DataRange1D(low_setting=-2.0, + high_setting=2.0) + + assert_array_equal(linspace(-2.0, 2.0, 101)**2, data_source.get_data()) + def test_set_mask(self): myfunc = lambda low, high: linspace(low, high, 101)**2 data_source = FunctionDataSource(func=myfunc) From 0479ba4d4cb888cef18d02766facc95913711b6e Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 21:54:57 +0000 Subject: [PATCH 15/30] Add tests for metadata. --- chaco/tests/arraydatasource_test_case.py | 21 +++++++++++++++++++ chaco/tests/function_data_source_test_case.py | 2 +- chaco/tests/image_data_test_case.py | 21 +++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index bb95430ed..7373eb98c 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -198,6 +198,27 @@ def test_data_size(self): data_source = ArrayDataSource(myarray) self.assertEqual(len(myarray), data_source.get_size()) + def test_metadata(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + + self.assertEqual(data_source.metadata, + {'annotations': [], 'selections': []}) + + def test_metadata_changed(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata = {'new_metadata': True} + + def test_metadata_items_changed(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata['new_metadata'] = True + class PointDataTestCase(unittest.TestCase): # Since PointData is mostly the same as ScalarData, the key things to diff --git a/chaco/tests/function_data_source_test_case.py b/chaco/tests/function_data_source_test_case.py index 548308833..2ab0e0cb3 100644 --- a/chaco/tests/function_data_source_test_case.py +++ b/chaco/tests/function_data_source_test_case.py @@ -4,7 +4,7 @@ import unittest2 as unittest -from numpy import array, linspace, nan, ones +from numpy import array, linspace, ones from numpy.testing import assert_array_equal import numpy as np diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py index 02b006fff..cfcc078a4 100644 --- a/chaco/tests/image_data_test_case.py +++ b/chaco/tests/image_data_test_case.py @@ -145,3 +145,24 @@ def test_fromfile_png_rgba(self): data_source = ImageData.fromfile(path) self.assertEqual(data_source.value_depth, 4) + + def test_metadata(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + self.assertEqual(data_source.metadata, + {'annotations': [], 'selections': []}) + + def test_metadata_changed(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata = {'new_metadata': True} + + def test_metadata_items_changed(self): + myarray = arange(15).reshape(5, 3, 1) + data_source = ImageData(data=myarray) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata['new_metadata'] = True From b7f87fb263e6fd5e64d4c1e4a93cf27fd37014a7 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 22:50:57 +0000 Subject: [PATCH 16/30] Modernise/improve grid data source tests. --- chaco/tests/grid_data_source_test_case.py | 87 +++++++++++++++-------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/chaco/tests/grid_data_source_test_case.py b/chaco/tests/grid_data_source_test_case.py index 544c91571..6e11b0ed5 100644 --- a/chaco/tests/grid_data_source_test_case.py +++ b/chaco/tests/grid_data_source_test_case.py @@ -1,57 +1,86 @@ -import unittest +import unittest2 as unittest from numpy import alltrue, array, ravel, isinf +from numpy.testing import assert_array_equal, assert_almost_equal from chaco.api import GridDataSource +from traits.testing.unittest_tools import UnittestTools -class GridDataSourceTestCase(unittest.TestCase): +class GridDataSourceTestCase(UnittestTools, unittest.TestCase): def test_empty(self): - ds = GridDataSource() - self.assert_(ds.sort_order == ('none', 'none')) - self.assert_(ds.index_dimension == 'image') - self.assert_(ds.value_dimension == 'scalar') - self.assert_(ds.metadata == {"selections":[], "annotations":[]}) - xdata, ydata = ds.get_data() - assert_ary_(xdata.get_data(), array([])) - assert_ary_(ydata.get_data(), array([])) - self.assert_(ds.get_bounds() == ((0,0),(0,0))) + data_source = GridDataSource() + self.assertEqual(data_source.sort_order, ('none', 'none')) + self.assertEqual(data_source.index_dimension, 'image') + self.assertEqual(data_source.value_dimension, 'scalar') + self.assertEqual(data_source.metadata, + {"selections":[], "annotations":[]}) + xdata, ydata = data_source.get_data() + assert_array_equal(xdata.get_data(), array([])) + assert_array_equal(ydata.get_data(), array([])) + self.assertEqual(data_source.get_bounds(), ((0,0),(0,0))) def test_init(self): test_xd = array([1,2,3]) test_yd = array([1.5, 0.5, -0.5, -1.5]) test_sort_order = ('ascending', 'descending') - ds = GridDataSource(xdata=test_xd, ydata=test_yd, - sort_order=test_sort_order) + data_source = GridDataSource(xdata=test_xd, ydata=test_yd, + sort_order=test_sort_order) - self.assert_(ds.sort_order == test_sort_order) - xd, yd = ds.get_data() - assert_ary_(xd.get_data(), test_xd) - assert_ary_(yd.get_data(), test_yd) - self.assert_(ds.get_bounds() == ((min(test_xd),min(test_yd)), - (max(test_xd),max(test_yd)))) + self.assertEqual(data_source.sort_order, test_sort_order) + xd, yd = data_source.get_data() + assert_array_equal(xd.get_data(), test_xd) + assert_array_equal(yd.get_data(), test_yd) + self.assertEqual(data_source.get_bounds(), + ((min(test_xd),min(test_yd)), + (max(test_xd),max(test_yd)))) def test_set_data(self): - ds = GridDataSource(xdata=array([1,2,3]), - ydata=array([1.5, 0.5, -0.5, -1.5]), - sort_order=('ascending', 'descending')) + data_source = GridDataSource(xdata=array([1,2,3]), + ydata=array([1.5, 0.5, -0.5, -1.5]), + sort_order=('ascending', 'descending')) test_xd = array([0,2,4]) test_yd = array([0,1,2,3,4,5]) test_sort_order = ('none', 'none') - ds.set_data(xdata=test_xd, ydata=test_yd, sort_order=('none', 'none')) + data_source.set_data(xdata=test_xd, ydata=test_yd, + sort_order=('none', 'none')) - self.assert_(ds.sort_order == test_sort_order) - xd, yd = ds.get_data() - assert_ary_(xd.get_data(), test_xd) - assert_ary_(yd.get_data(), test_yd) - self.assert_(ds.get_bounds() == ((min(test_xd),min(test_yd)), - (max(test_xd),max(test_yd)))) + self.assertEqual(data_source.sort_order, test_sort_order) + xd, yd = data_source.get_data() + assert_array_equal(xd.get_data(), test_xd) + assert_array_equal(yd.get_data(), test_yd) + self.assertEqual(data_source.get_bounds(), + ((min(test_xd),min(test_yd)), + (max(test_xd),max(test_yd)))) + def test_metadata(self): + data_source = GridDataSource(xdata=array([1,2,3]), + ydata=array([1.5, 0.5, -0.5, -1.5]), + sort_order=('ascending', 'descending')) + + self.assertEqual(data_source.metadata, + {'annotations': [], 'selections': []}) + + def test_metadata_changed(self): + data_source = GridDataSource(xdata=array([1,2,3]), + ydata=array([1.5, 0.5, -0.5, -1.5]), + sort_order=('ascending', 'descending')) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata = {'new_metadata': True} + + def test_metadata_items_changed(self): + data_source = GridDataSource(xdata=array([1,2,3]), + ydata=array([1.5, 0.5, -0.5, -1.5]), + sort_order=('ascending', 'descending')) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata['new_metadata'] = True From 60aed7c924f2c6a9953c44ae4ffa66690eacdfac Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 17 Dec 2014 23:32:53 +0000 Subject: [PATCH 17/30] Add multi array data source test case. --- chaco/tests/image_data_test_case.py | 1 + .../multi_array_data_source_test_case.py | 159 ++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 chaco/tests/multi_array_data_source_test_case.py diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py index cfcc078a4..7efd624cd 100644 --- a/chaco/tests/image_data_test_case.py +++ b/chaco/tests/image_data_test_case.py @@ -21,6 +21,7 @@ class ArrayDataTestCase(UnittestTools, unittest.TestCase): def test_init_defaults(self): data_source = ImageData() assert_array_equal(data_source.data, []) + # this isn't right - #self.assertEqual(data_source.value_dimension, "scalar") #self.assertEqual(data_source.image_dimension, "image") diff --git a/chaco/tests/multi_array_data_source_test_case.py b/chaco/tests/multi_array_data_source_test_case.py new file mode 100644 index 000000000..eadb9ff6a --- /dev/null +++ b/chaco/tests/multi_array_data_source_test_case.py @@ -0,0 +1,159 @@ +""" +Test of basic dataseries behavior. +""" + +import unittest2 as unittest + +from numpy import arange, array, allclose, empty, isnan, nan, ones +from numpy.testing import assert_array_equal +import numpy as np + +from chaco.api import MultiArrayDataSource +from traits.testing.unittest_tools import UnittestTools + + +class ArrayDataTestCase(UnittestTools, unittest.TestCase): + + def test_init_defaults(self): + data_source = MultiArrayDataSource() + assert_array_equal(data_source._data, empty(shape=(0, 1), dtype=float)) + # XXX this doesn't match AbstractDataSource's interface + self.assertEqual(data_source.value_dimension, 1) + self.assertEqual(data_source.sort_order, "ascending") + self.assertFalse(data_source.is_masked()) + + def test_basic_setup(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + assert_array_equal(myarray, data_source._data) + self.assertEqual(data_source.index_dimension, 0) + self.assertEqual(data_source.value_dimension, 1) + self.assertEqual(data_source.sort_order, "ascending") + self.assertFalse(data_source.is_masked()) + + def test_set_data(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + new_array = arange(0, 40, 2).reshape(10, 2) + + with self.assertTraitChanges(data_source, 'data_changed', count=1): + data_source.set_data(new_array) + + assert_array_equal(new_array, data_source._data) + self.assertEqual(data_source.get_bounds(), (0, 38)) + self.assertEqual(data_source.sort_order, "ascending") + + def test_get_data(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + assert_array_equal(myarray, data_source.get_data()) + + def test_get_data_no_data(self): + data_source = MultiArrayDataSource() + + assert_array_equal(data_source.get_data(), + empty(shape=(0, 1), dtype=float)) + + def test_get_data_mask(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + data, mask = data_source.get_data_mask() + assert_array_equal(data, myarray) + assert_array_equal(mask, ones(shape=(10, 2), dtype=bool)) + + def test_bounds(self): + # ascending + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 19)) + + # descending + myarray = arange(20)[::-1].reshape(10, 2) + data_source = MultiArrayDataSource(myarray, sort_order='descending') + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 19)) + + # no order + myarray = array([[12, 3], [0, 9], [2, 18], [3, 10]]) + data_source = MultiArrayDataSource(myarray, sort_order="none") + bounds = data_source.get_bounds() + self.assertEqual(bounds, (0, 18)) + + def test_bounds_value(self): + # ascending + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + bounds = data_source.get_bounds(value=0) + self.assertEqual(bounds, (0, 18)) + + # descending + myarray = arange(20)[::-1].reshape(10, 2) + data_source = MultiArrayDataSource(myarray, sort_order='descending') + bounds = data_source.get_bounds(value=0) + self.assertEqual(bounds, (1, 19)) + + # no order + myarray = array([[12, 3], [0, 9], [2, 18], [3, 10]]) + data_source = MultiArrayDataSource(myarray, sort_order="none") + bounds = data_source.get_bounds(value=0) + self.assertEqual(bounds, (0, 12)) + + def test_bounds_index(self): + # ascending + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + bounds = data_source.get_bounds(index=0) + self.assertEqual(bounds, (0, 1)) + + # descending + myarray = arange(20)[::-1].reshape(10, 2) + data_source = MultiArrayDataSource(myarray, sort_order='descending') + bounds = data_source.get_bounds(index=0) + self.assertEqual(bounds, (18, 19)) + + # no order + myarray = array([[12, 3], [0, 9], [2, 18], [3, 10]]) + data_source = MultiArrayDataSource(myarray, sort_order="none") + bounds = data_source.get_bounds(index=0) + self.assertEqual(bounds, (3, 12)) + + def test_bounds_empty(self): + data_source = MultiArrayDataSource() + bounds = data_source.get_bounds() + # XXX this is sort of inconsistent with test_bounds_all_nans() + self.assertEqual(bounds, (0, 0)) + + def test_bounds_all_nans(self): + myarray = empty((10,2)) + myarray[:, :] = nan + data_source = MultiArrayDataSource(myarray) + bounds = data_source.get_bounds() + self.assertTrue(isnan(bounds[0])) + self.assertTrue(isnan(bounds[1])) + + def test_metadata(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + self.assertEqual(data_source.metadata, + {'annotations': [], 'selections': []}) + + @unittest.skip('change handler missing from class') + def test_metadata_changed(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata = {'new_metadata': True} + + @unittest.skip('change handler missing from class') + def test_metadata_items_changed(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + with self.assertTraitChanges(data_source, 'metadata_changed', count=1): + data_source.metadata['new_metadata'] = True From 93772da9d0b9125f7095a339679438bad59a18af Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Thu, 18 Dec 2014 10:41:33 +0000 Subject: [PATCH 18/30] Tests around serialization methods and reverse maps. --- chaco/tests/arraydatasource_test_case.py | 67 ++++++++++++++++++- .../multi_array_data_source_test_case.py | 12 +++- 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index 7373eb98c..55cff2af3 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -2,8 +2,9 @@ Test of basic dataseries behavior. """ -import unittest2 as unittest +import pickle +import unittest2 as unittest from numpy import arange, array, allclose, empty, isnan, nan, ones from numpy.testing import assert_array_equal import numpy as np @@ -198,6 +199,25 @@ def test_data_size(self): data_source = ArrayDataSource(myarray) self.assertEqual(len(myarray), data_source.get_size()) + def test_reverse_map(self): + # sort_order ascending + myarray = arange(10) + data_source = ArrayDataSource(myarray, sort_order='ascending') + + self.assertEqual(data_source.reverse_map(4.0), 4) + + # sort_order descending + myarray = arange(10)[::-1] + data_source = ArrayDataSource(myarray, sort_order='descending') + + self.assertEqual(data_source.reverse_map(4.0), 5) + + # sort_order none + myarray = array([12, 3, 0, 9, 2, 18, 3]) + data_source = ArrayDataSource(myarray, sort_order='ascending') + + self.assertEqual(data_source.reverse_map(3), None) + def test_metadata(self): myarray = arange(10) data_source = ArrayDataSource(myarray) @@ -219,6 +239,51 @@ def test_metadata_items_changed(self): with self.assertTraitChanges(data_source, 'metadata_changed', count=1): data_source.metadata['new_metadata'] = True + def test_serialization_state(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + + state = data_source.__getstate__() + self.assertTrue('value_dimension' not in state) + self.assertTrue('index_dimension' not in state) + self.assertTrue('persist_data' not in state) + + @unittest.skip("persist_data probably shouldn't be persisted") + def test_serialization_state_no_persist(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + data_source.persist_data = False + + state = data_source.__getstate__() + self.assertTrue('value_dimension' not in state) + self.assertTrue('index_dimension' not in state) + self.assertTrue('persist_data' not in state) + for key in {"_data", "_cached_mask", "_cached_bounds", + "_min_index", "_max_index"}: + self.assertTrue(key not in state) + + + @unittest.skip("I think this is just broken") + def test_serialization_post_load(self): + myarray = arange(10) + data_source = ArrayDataSource(myarray) + mymask = array([i % 2 for i in myarray], dtype=bool) + data_source.set_mask(mymask) + + pickled_data_source = pickle.dumps(data_source) + unpickled_data_source = pickle.loads(pickled_data_source) + unpickled_data_source._post_load() + + self.assertEqual(unpickled_data_source._cached_bounds, ()) + self.assertEqual(unpickled_data_source._cached_mask, None) + + assert_array_equal(data_source.get_data(), + unpickled_data_source.get_data()) + + mask = unpickled_data_source.get_data_mask()[1] + assert_array_equal(mask, ones(10)) + + class PointDataTestCase(unittest.TestCase): # Since PointData is mostly the same as ScalarData, the key things to diff --git a/chaco/tests/multi_array_data_source_test_case.py b/chaco/tests/multi_array_data_source_test_case.py index eadb9ff6a..2fc44d975 100644 --- a/chaco/tests/multi_array_data_source_test_case.py +++ b/chaco/tests/multi_array_data_source_test_case.py @@ -4,15 +4,14 @@ import unittest2 as unittest -from numpy import arange, array, allclose, empty, isnan, nan, ones +from numpy import arange, array, empty, isnan, nan, ones from numpy.testing import assert_array_equal -import numpy as np from chaco.api import MultiArrayDataSource from traits.testing.unittest_tools import UnittestTools -class ArrayDataTestCase(UnittestTools, unittest.TestCase): +class MultiArrayDataTestCase(UnittestTools, unittest.TestCase): def test_init_defaults(self): data_source = MultiArrayDataSource() @@ -27,6 +26,7 @@ def test_basic_setup(self): data_source = MultiArrayDataSource(myarray) assert_array_equal(myarray, data_source._data) + # XXX this doesn't match AbstractDataSource's interface self.assertEqual(data_source.index_dimension, 0) self.assertEqual(data_source.value_dimension, 1) self.assertEqual(data_source.sort_order, "ascending") @@ -50,6 +50,12 @@ def test_get_data(self): assert_array_equal(myarray, data_source.get_data()) + def test_get_data_axes(self): + myarray = arange(20).reshape(10, 2) + data_source = MultiArrayDataSource(myarray) + + assert_array_equal(arange(0, 20, 2), data_source.get_data(axes=0)) + def test_get_data_no_data(self): data_source = MultiArrayDataSource() From 9897383268af19516193735591216647e45784c1 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Thu, 18 Dec 2014 10:43:44 +0000 Subject: [PATCH 19/30] Update CHANGES.txt. --- CHANGES.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGES.txt b/CHANGES.txt index e8f9fc5b3..4934652c9 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -4,6 +4,10 @@ Chaco CHANGELOG Change summary since 4.5.0 +Enhancements + + * More comprehensive testing for AbstractDataSource subclasses (PR #244). + Fixes * Workaround RuntimeWarnings from nanmin and nanmax in ImageData.get_bounds @@ -20,7 +24,7 @@ New features/Improvements * Added perceptual colormaps by Matteo Niccoli, Dave Green and Kenneth Moreland. * Added `asynchronous_updates.py` demo that shows a pattern for generating expensive plots while keeping the interface responsive (PR#170). - * Speeded up by 10x the data mappers of the `GridMapper` class (mapping of 2D data + * Speeded up by 10x the data mappers of the `GridMapper` class (mapping of 2D data to/from screen space). From f25cb9f3475841a1e54e4619a224a18864e57f41 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Thu, 18 Dec 2014 10:49:03 +0000 Subject: [PATCH 20/30] Flake8. --- chaco/tests/arraydatasource_test_case.py | 7 +++---- chaco/tests/function_data_source_test_case.py | 1 - chaco/tests/multi_array_data_source_test_case.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index 55cff2af3..df4fbb787 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -262,7 +262,6 @@ def test_serialization_state_no_persist(self): "_min_index", "_max_index"}: self.assertTrue(key not in state) - @unittest.skip("I think this is just broken") def test_serialization_post_load(self): myarray = arange(10) @@ -284,7 +283,6 @@ def test_serialization_post_load(self): assert_array_equal(mask, ones(10)) - class PointDataTestCase(unittest.TestCase): # Since PointData is mostly the same as ScalarData, the key things to # test are functionality that use _compute_bounds() and reverse_map(). @@ -294,16 +292,17 @@ def create_array(self): def test_basic_set_get(self): myarray = self.create_array() pd = PointDataSource(myarray) - self.assertTrue(allclose(myarray,pd._data)) + self.assertTrue(allclose(myarray, pd._data)) self.assert_(pd.value_dimension == "point") return def test_bounds(self): myarray = self.create_array() pd = PointDataSource(myarray) - self.assertEqual(pd.get_bounds(),((0,0), (9,90))) + self.assertEqual(pd.get_bounds(), ((0, 0), (9, 90))) return + if __name__ == '__main__': import nose nose.run() diff --git a/chaco/tests/function_data_source_test_case.py b/chaco/tests/function_data_source_test_case.py index 2ab0e0cb3..29e04f315 100644 --- a/chaco/tests/function_data_source_test_case.py +++ b/chaco/tests/function_data_source_test_case.py @@ -6,7 +6,6 @@ from numpy import array, linspace, ones from numpy.testing import assert_array_equal -import numpy as np from chaco.api import DataRange1D from chaco.function_data_source import FunctionDataSource diff --git a/chaco/tests/multi_array_data_source_test_case.py b/chaco/tests/multi_array_data_source_test_case.py index 2fc44d975..7e448f5a4 100644 --- a/chaco/tests/multi_array_data_source_test_case.py +++ b/chaco/tests/multi_array_data_source_test_case.py @@ -134,7 +134,7 @@ def test_bounds_empty(self): self.assertEqual(bounds, (0, 0)) def test_bounds_all_nans(self): - myarray = empty((10,2)) + myarray = empty((10, 2)) myarray[:, :] = nan data_source = MultiArrayDataSource(myarray) bounds = data_source.get_bounds() From bfb0883f5b60b9e76b287f2d945981a83ca4d536 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Thu, 18 Dec 2014 13:16:13 +0000 Subject: [PATCH 21/30] Remove set literal from test for 2.6 comptibility. --- chaco/tests/arraydatasource_test_case.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index df4fbb787..26f7210ee 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -258,8 +258,8 @@ def test_serialization_state_no_persist(self): self.assertTrue('value_dimension' not in state) self.assertTrue('index_dimension' not in state) self.assertTrue('persist_data' not in state) - for key in {"_data", "_cached_mask", "_cached_bounds", - "_min_index", "_max_index"}: + for key in ["_data", "_cached_mask", "_cached_bounds", "_min_index", + "_max_index"]: self.assertTrue(key not in state) @unittest.skip("I think this is just broken") From 2b25fda416d95d40c8170817ca1622d2b23f47bf Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Fri, 19 Dec 2014 19:06:50 +0000 Subject: [PATCH 22/30] Add setup to array data source. --- chaco/tests/arraydatasource_test_case.py | 154 ++++++++++------------- 1 file changed, 63 insertions(+), 91 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index 26f7210ee..c4a47423a 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -1,5 +1,5 @@ """ -Test of basic dataseries behavior. +Test of ArrayDataSource behavior. """ import pickle @@ -13,7 +13,12 @@ from traits.testing.unittest_tools import UnittestTools -class ArrayDataTestCase(UnittestTools, unittest.TestCase): +class ArrayDataSourceTest(UnittestTools, unittest.TestCase): + + def setUp(self): + self.myarray = arange(10) + self.mymask = array([i % 2 for i in self.myarray], dtype=bool) + self.data_source = ArrayDataSource(self.myarray) def test_init_defaults(self): data_source = ArrayDataSource() @@ -23,70 +28,58 @@ def test_init_defaults(self): self.assertFalse(data_source.is_masked()) def test_basic_setup(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - assert_array_equal(myarray, data_source._data) - self.assertEqual(data_source.value_dimension, "scalar") - self.assertEqual(data_source.sort_order, "none") - self.assertFalse(data_source.is_masked()) + assert_array_equal(self.myarray, self.data_source._data) + self.assertEqual(self.data_source.value_dimension, "scalar") + self.assertEqual(self.data_source.sort_order, "none") + self.assertFalse(self.data_source.is_masked()) def test_set_data(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) new_array = arange(0, 20, 2) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.set_data(new_array) + with self.assertTraitChanges(self.data_source, 'data_changed', + count=1): + self.data_source.set_data(new_array) - assert_array_equal(new_array, data_source._data) - self.assertEqual(data_source.get_bounds(), (0, 18)) - self.assertEqual(data_source.sort_order, "none") + assert_array_equal(new_array, self.data_source._data) + self.assertEqual(self.data_source.get_bounds(), (0, 18)) + self.assertEqual(self.data_source.sort_order, "none") def test_set_data_ordered(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) new_array = arange(20, 0, -2) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.set_data(new_array, sort_order='descending') + with self.assertTraitChanges(self.data_source, 'data_changed', + count=1): + self.data_source.set_data(new_array, sort_order='descending') - assert_array_equal(new_array, data_source._data) - self.assertEqual(data_source.get_bounds(), (2, 20)) - self.assertEqual(data_source.sort_order, "descending") + assert_array_equal(new_array, self.data_source._data) + self.assertEqual(self.data_source.get_bounds(), (2, 20)) + self.assertEqual(self.data_source.sort_order, "descending") def test_set_mask(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - mymask = array([i % 2 for i in myarray], dtype=bool) - - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.set_mask(mymask) + with self.assertTraitChanges(self.data_source, 'data_changed', + count=1): + self.data_source.set_mask(self.mymask) - assert_array_equal(myarray, data_source._data) - assert_array_equal(mymask, data_source._cached_mask) - self.assertTrue(data_source.is_masked()) - self.assertEqual(data_source.get_bounds(), (0, 9)) + assert_array_equal(self.myarray, self.data_source._data) + assert_array_equal(self.mymask, self.data_source._cached_mask) + self.assertTrue(self.data_source.is_masked()) + self.assertEqual(self.data_source.get_bounds(), (0, 9)) def test_remove_mask(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - mymask = array([i % 2 for i in myarray], dtype=bool) - data_source.set_mask(mymask) - self.assertTrue(data_source.is_masked()) + self.data_source.set_mask(self.mymask) + self.assertTrue(self.data_source.is_masked()) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.remove_mask() + with self.assertTraitChanges(self.data_source, 'data_changed', + count=1): + self.data_source.remove_mask() - assert_array_equal(myarray, data_source._data) - self.assertIsNone(data_source._cached_mask, None) - self.assertFalse(data_source.is_masked()) - self.assertEqual(data_source.get_bounds(), (0, 9)) + assert_array_equal(self.myarray, self.data_source._data) + self.assertIsNone(self.data_source._cached_mask, None) + self.assertFalse(self.data_source.is_masked()) + self.assertEqual(self.data_source.get_bounds(), (0, 9)) def test_get_data(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - - assert_array_equal(myarray, data_source.get_data()) + assert_array_equal(self.myarray, self.data_source.get_data()) def test_get_data_no_data(self): data_source = ArrayDataSource(None) @@ -95,13 +88,11 @@ def test_get_data_no_data(self): assert_array_equal(data_source.get_data(), 0.0) def test_get_data_mask(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - mymask = array([i % 2 for i in myarray], dtype=bool) - data_source.set_mask(mymask) + self.data_source.set_mask(self.mymask) - data, mask = data_source.get_data_mask() - assert_array_equal(data, myarray) + data, mask = self.data_source.get_data_mask() + assert_array_equal(data, self.myarray) + assert_array_equal(mask, self.mymask) @unittest.skip('get_data_mask() fails in this case') def test_get_data_mask_no_data(self): @@ -113,18 +104,13 @@ def test_get_data_mask_no_data(self): assert_array_equal(data, True) def test_get_data_mask_no_mask(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - - data, mask = data_source.get_data_mask() - assert_array_equal(data, myarray) + data, mask = self.data_source.get_data_mask() + assert_array_equal(data, self.myarray) assert_array_equal(mask, ones(shape=10, dtype=bool)) def test_bounds(self): # ascending - myarray = arange(10) - data_source = ArrayDataSource(myarray, sort_order="ascending") - bounds = data_source.get_bounds() + bounds = self.data_source.get_bounds() self.assertEqual(bounds, (0, 9)) # descending @@ -214,47 +200,36 @@ def test_reverse_map(self): # sort_order none myarray = array([12, 3, 0, 9, 2, 18, 3]) - data_source = ArrayDataSource(myarray, sort_order='ascending') + data_source = ArrayDataSource(myarray, sort_order='none') - self.assertEqual(data_source.reverse_map(3), None) + with self.assertRaises(NotImplementedError): + data_source.reverse_map(3) def test_metadata(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - - self.assertEqual(data_source.metadata, + self.assertEqual(self.data_source.metadata, {'annotations': [], 'selections': []}) def test_metadata_changed(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata = {'new_metadata': True} + with self.assertTraitChanges(self.data_source, 'metadata_changed', + count=1): + self.data_source.metadata = {'new_metadata': True} def test_metadata_items_changed(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata['new_metadata'] = True + with self.assertTraitChanges(self.data_source, 'metadata_changed', + count=1): + self.data_source.metadata['new_metadata'] = True def test_serialization_state(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - - state = data_source.__getstate__() + state = self.data_source.__getstate__() self.assertTrue('value_dimension' not in state) self.assertTrue('index_dimension' not in state) self.assertTrue('persist_data' not in state) @unittest.skip("persist_data probably shouldn't be persisted") def test_serialization_state_no_persist(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - data_source.persist_data = False + self.data_source.persist_data = False - state = data_source.__getstate__() + state = self.data_source.__getstate__() self.assertTrue('value_dimension' not in state) self.assertTrue('index_dimension' not in state) self.assertTrue('persist_data' not in state) @@ -264,19 +239,16 @@ def test_serialization_state_no_persist(self): @unittest.skip("I think this is just broken") def test_serialization_post_load(self): - myarray = arange(10) - data_source = ArrayDataSource(myarray) - mymask = array([i % 2 for i in myarray], dtype=bool) - data_source.set_mask(mymask) + self.data_source.set_mask(self.mymask) - pickled_data_source = pickle.dumps(data_source) + pickled_data_source = pickle.dumps(self.data_source) unpickled_data_source = pickle.loads(pickled_data_source) unpickled_data_source._post_load() self.assertEqual(unpickled_data_source._cached_bounds, ()) self.assertEqual(unpickled_data_source._cached_mask, None) - assert_array_equal(data_source.get_data(), + assert_array_equal(self.data_source.get_data(), unpickled_data_source.get_data()) mask = unpickled_data_source.get_data_mask()[1] From d743dd93ca9fddc29dfea79e5196cfcd030d50d7 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Tue, 23 Dec 2014 12:47:50 +0000 Subject: [PATCH 23/30] Add setUp method to multiarray data source test case. --- .../multi_array_data_source_test_case.py | 77 +++++++------------ 1 file changed, 26 insertions(+), 51 deletions(-) diff --git a/chaco/tests/multi_array_data_source_test_case.py b/chaco/tests/multi_array_data_source_test_case.py index 7e448f5a4..566b8229b 100644 --- a/chaco/tests/multi_array_data_source_test_case.py +++ b/chaco/tests/multi_array_data_source_test_case.py @@ -13,6 +13,10 @@ class MultiArrayDataTestCase(UnittestTools, unittest.TestCase): + def setUp(self): + self.myarray = arange(20).reshape(10, 2) + self.data_source = MultiArrayDataSource(self.myarray) + def test_init_defaults(self): data_source = MultiArrayDataSource() assert_array_equal(data_source._data, empty(shape=(0, 1), dtype=float)) @@ -22,39 +26,28 @@ def test_init_defaults(self): self.assertFalse(data_source.is_masked()) def test_basic_setup(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - assert_array_equal(myarray, data_source._data) + assert_array_equal(self.myarray, self.data_source._data) # XXX this doesn't match AbstractDataSource's interface - self.assertEqual(data_source.index_dimension, 0) - self.assertEqual(data_source.value_dimension, 1) - self.assertEqual(data_source.sort_order, "ascending") - self.assertFalse(data_source.is_masked()) + self.assertEqual(self.data_source.index_dimension, 0) + self.assertEqual(self.data_source.value_dimension, 1) + self.assertEqual(self.data_source.sort_order, "ascending") + self.assertFalse(self.data_source.is_masked()) def test_set_data(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) new_array = arange(0, 40, 2).reshape(10, 2) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.set_data(new_array) + with self.assertTraitChanges(self.data_source, 'data_changed', count=1): + self.data_source.set_data(new_array) - assert_array_equal(new_array, data_source._data) - self.assertEqual(data_source.get_bounds(), (0, 38)) - self.assertEqual(data_source.sort_order, "ascending") + assert_array_equal(new_array, self.data_source._data) + self.assertEqual(self.data_source.get_bounds(), (0, 38)) + self.assertEqual(self.data_source.sort_order, "ascending") def test_get_data(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - assert_array_equal(myarray, data_source.get_data()) + assert_array_equal(self.myarray, self.data_source.get_data()) def test_get_data_axes(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - assert_array_equal(arange(0, 20, 2), data_source.get_data(axes=0)) + assert_array_equal(arange(0, 20, 2), self.data_source.get_data(axes=0)) def test_get_data_no_data(self): data_source = MultiArrayDataSource() @@ -63,18 +56,13 @@ def test_get_data_no_data(self): empty(shape=(0, 1), dtype=float)) def test_get_data_mask(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - data, mask = data_source.get_data_mask() - assert_array_equal(data, myarray) + data, mask = self.data_source.get_data_mask() + assert_array_equal(data, self.myarray) assert_array_equal(mask, ones(shape=(10, 2), dtype=bool)) def test_bounds(self): # ascending - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - bounds = data_source.get_bounds() + bounds = self.data_source.get_bounds() self.assertEqual(bounds, (0, 19)) # descending @@ -91,9 +79,7 @@ def test_bounds(self): def test_bounds_value(self): # ascending - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - bounds = data_source.get_bounds(value=0) + bounds = self.data_source.get_bounds(value=0) self.assertEqual(bounds, (0, 18)) # descending @@ -110,9 +96,7 @@ def test_bounds_value(self): def test_bounds_index(self): # ascending - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - bounds = data_source.get_bounds(index=0) + bounds = self.data_source.get_bounds(index=0) self.assertEqual(bounds, (0, 1)) # descending @@ -142,24 +126,15 @@ def test_bounds_all_nans(self): self.assertTrue(isnan(bounds[1])) def test_metadata(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - self.assertEqual(data_source.metadata, + self.assertEqual(self.data_source.metadata, {'annotations': [], 'selections': []}) @unittest.skip('change handler missing from class') def test_metadata_changed(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata = {'new_metadata': True} + with self.assertTraitChanges(self.data_source, 'metadata_changed', count=1): + self.data_source.metadata = {'new_metadata': True} @unittest.skip('change handler missing from class') def test_metadata_items_changed(self): - myarray = arange(20).reshape(10, 2) - data_source = MultiArrayDataSource(myarray) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata['new_metadata'] = True + with self.assertTraitChanges(self.data_source, 'metadata_changed', count=1): + self.data_source.metadata['new_metadata'] = True From bc02978bccd8bfd9a41be5c9ceac4518a2137d17 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Tue, 23 Dec 2014 13:15:43 +0000 Subject: [PATCH 24/30] Improvements based on suggestions in PR. --- chaco/tests/arraydatasource_test_case.py | 24 ++-- chaco/tests/function_data_source_test_case.py | 112 ++++++++---------- chaco/tests/grid_data_source_test_case.py | 74 ++++-------- chaco/tests/image_data_test_case.py | 80 ++++--------- .../multi_array_data_source_test_case.py | 2 +- 5 files changed, 114 insertions(+), 178 deletions(-) diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index c4a47423a..341d9c9f0 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -1,5 +1,5 @@ """ -Test of ArrayDataSource behavior. +Tests of ArrayDataSource behavior. """ import pickle @@ -13,7 +13,7 @@ from traits.testing.unittest_tools import UnittestTools -class ArrayDataSourceTest(UnittestTools, unittest.TestCase): +class ArrayDataSourceTestCase(UnittestTools, unittest.TestCase): def setUp(self): self.myarray = arange(10) @@ -24,8 +24,10 @@ def test_init_defaults(self): data_source = ArrayDataSource() assert_array_equal(data_source._data, []) self.assertEqual(data_source.value_dimension, "scalar") + self.assertEqual(data_source.index_dimension, "scalar") self.assertEqual(data_source.sort_order, "none") self.assertFalse(data_source.is_masked()) + self.assertEqual(data_source.persist_data, True) def test_basic_setup(self): assert_array_equal(self.myarray, self.data_source._data) @@ -84,7 +86,6 @@ def test_get_data(self): def test_get_data_no_data(self): data_source = ArrayDataSource(None) - # XXX A _scalar_? Not array([]) or None? assert_array_equal(data_source.get_data(), 0.0) def test_get_data_mask(self): @@ -99,9 +100,8 @@ def test_get_data_mask_no_data(self): data_source = ArrayDataSource(None) data, mask = data_source.get_data_mask() - # XXX this is what I would expect, given get_data() behaviour assert_array_equal(data, 0.0) - assert_array_equal(data, True) + assert_array_equal(mask, True) def test_get_data_mask_no_mask(self): data, mask = self.data_source.get_data_mask() @@ -221,21 +221,21 @@ def test_metadata_items_changed(self): def test_serialization_state(self): state = self.data_source.__getstate__() - self.assertTrue('value_dimension' not in state) - self.assertTrue('index_dimension' not in state) - self.assertTrue('persist_data' not in state) + self.assertNotIn('value_dimension', state) + self.assertNotIn('index_dimension', state) + self.assertNotIn('persist_data', state) @unittest.skip("persist_data probably shouldn't be persisted") def test_serialization_state_no_persist(self): self.data_source.persist_data = False state = self.data_source.__getstate__() - self.assertTrue('value_dimension' not in state) - self.assertTrue('index_dimension' not in state) - self.assertTrue('persist_data' not in state) + self.assertNotIn('value_dimension', state) + self.assertNotIn('index_dimension', state) + self.assertNotIn('persist_data', state) for key in ["_data", "_cached_mask", "_cached_bounds", "_min_index", "_max_index"]: - self.assertTrue(key not in state) + self.assertIn(key, state) @unittest.skip("I think this is just broken") def test_serialization_post_load(self): diff --git a/chaco/tests/function_data_source_test_case.py b/chaco/tests/function_data_source_test_case.py index 29e04f315..d4717600d 100644 --- a/chaco/tests/function_data_source_test_case.py +++ b/chaco/tests/function_data_source_test_case.py @@ -1,5 +1,5 @@ """ -Test of basic dataseries behavior. +Test of FunctionDataSource behavior. """ import unittest2 as unittest @@ -14,6 +14,11 @@ class FunctionDataSourceTestCase(UnittestTools, unittest.TestCase): + def setUp(self): + self.myfunc = lambda low, high: linspace(low, high, 101)**2 + self.data_source = FunctionDataSource(func=self.myfunc) + + def test_init_defaults(self): data_source = FunctionDataSource() assert_array_equal(data_source._data, []) @@ -22,109 +27,94 @@ def test_init_defaults(self): self.assertFalse(data_source.is_masked()) def test_basic_setup(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - assert_array_equal(myfunc, data_source.func) - self.assertEqual(data_source.value_dimension, "scalar") - self.assertEqual(data_source.sort_order, "ascending") - self.assertFalse(data_source.is_masked()) + assert_array_equal(self.myfunc, self.data_source.func) + self.assertEqual(self.data_source.value_dimension, "scalar") + self.assertEqual(self.data_source.sort_order, "ascending") + self.assertFalse(self.data_source.is_masked()) def test_set_data(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - with self.assertRaises(RuntimeError): - data_source.set_data(lambda low, high: linspace(low, high, 101)) + self.data_source.set_data( + lambda low, high: linspace(low, high, 101)) def test_range_high_changed(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=1.0) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.data_range.high_setting = 2.0 + with self.assertTraitChanges(self.data_source, 'data_changed', count=1): + self.data_source.data_range.high_setting = 2.0 - assert_array_equal(linspace(0.0, 2.0, 101)**2, data_source.get_data()) + assert_array_equal(linspace(0.0, 2.0, 101)**2, + self.data_source.get_data()) def test_range_low_changed(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=1.0) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.data_range.low_setting = -1.0 + with self.assertTraitChanges(self.data_source, 'data_changed', count=1): + self.data_source.data_range.low_setting = -1.0 - assert_array_equal(linspace(-1.0, 1.0, 101)**2, data_source.get_data()) + assert_array_equal(linspace(-1.0, 1.0, 101)**2, + self.data_source.get_data()) def test_range_data_range_changed(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=1.0) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.data_range = DataRange1D(low_setting=-2.0, - high_setting=2.0) + with self.assertTraitChanges(self.data_source, 'data_changed', count=1): + self.data_source.data_range = DataRange1D(low_setting=-2.0, + high_setting=2.0) - assert_array_equal(linspace(-2.0, 2.0, 101)**2, data_source.get_data()) + assert_array_equal(linspace(-2.0, 2.0, 101)**2, + self.data_source.get_data()) def test_set_mask(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) mymask = array([i % 2 for i in xrange(101)], dtype=bool) with self.assertRaises(NotImplementedError): - data_source.set_mask(mymask) + self.data_source.set_mask(mymask) def test_remove_mask(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - with self.assertRaises(NotImplementedError): - data_source.remove_mask() + self.data_source.remove_mask() def test_get_data(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=1.0) - assert_array_equal(linspace(0.0, 1.0, 101)**2, data_source.get_data()) + assert_array_equal(linspace(0.0, 1.0, 101)**2, + self.data_source.get_data()) def test_get_data_no_data(self): - data_source = FunctionDataSource() + self.data_source = FunctionDataSource() - assert_array_equal(data_source.get_data(), array([], dtype=float)) + assert_array_equal(self.data_source.get_data(), array([], dtype=float)) def test_get_data_mask(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, high_setting=1.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=1.0) - data, mask = data_source.get_data_mask() + data, mask = self.data_source.get_data_mask() assert_array_equal(data, linspace(0.0, 1.0, 101)**2) assert_array_equal(mask, ones(shape=101, dtype=bool)) def test_bounds(self): - myfunc = lambda low, high: linspace(low, high, 100)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, high_setting=2.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=2.0) - bounds = data_source.get_bounds() + bounds = self.data_source.get_bounds() self.assertEqual(bounds, (0.0, 4.0)) @unittest.skip("default sort_order is ascending, which isn't right") def test_bounds_non_monotone(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=-2.0, - high_setting=2.0) + self.data_source.data_range = DataRange1D(low_setting=-2.0, + high_setting=2.0) - bounds = data_source.get_bounds() + bounds = self.data_source.get_bounds() self.assertEqual(bounds, (0.0, 4.0)) def test_data_size(self): - myfunc = lambda low, high: linspace(low, high, 101)**2 - data_source = FunctionDataSource(func=myfunc) - data_source.data_range = DataRange1D(low_setting=0.0, - high_setting=2.0) + self.data_source.data_range = DataRange1D(low_setting=0.0, + high_setting=2.0) - self.assertEqual(101, data_source.get_size()) + self.assertEqual(101, self.data_source.get_size()) diff --git a/chaco/tests/grid_data_source_test_case.py b/chaco/tests/grid_data_source_test_case.py index 6e11b0ed5..32425c123 100644 --- a/chaco/tests/grid_data_source_test_case.py +++ b/chaco/tests/grid_data_source_test_case.py @@ -1,8 +1,11 @@ +""" +Tests of GridDataSource behavior. +""" import unittest2 as unittest -from numpy import alltrue, array, ravel, isinf -from numpy.testing import assert_array_equal, assert_almost_equal +from numpy import array +from numpy.testing import assert_array_equal from chaco.api import GridDataSource from traits.testing.unittest_tools import UnittestTools @@ -10,6 +13,12 @@ class GridDataSourceTestCase(UnittestTools, unittest.TestCase): + def setUp(self): + self.data_source = GridDataSource( + xdata=array([1, 2, 3]), + ydata=array([1.5, 0.5, -0.5, -1.5]), + sort_order=('ascending', 'descending')) + def test_empty(self): data_source = GridDataSource() self.assertEqual(data_source.sort_order, ('none', 'none')) @@ -23,81 +32,46 @@ def test_empty(self): self.assertEqual(data_source.get_bounds(), ((0,0),(0,0))) def test_init(self): - test_xd = array([1,2,3]) + test_xd = array([1, 2, 3]) test_yd = array([1.5, 0.5, -0.5, -1.5]) test_sort_order = ('ascending', 'descending') - data_source = GridDataSource(xdata=test_xd, ydata=test_yd, - sort_order=test_sort_order) - - self.assertEqual(data_source.sort_order, test_sort_order) - xd, yd = data_source.get_data() + self.assertEqual(self.data_source.sort_order, test_sort_order) + xd, yd = self.data_source.get_data() assert_array_equal(xd.get_data(), test_xd) assert_array_equal(yd.get_data(), test_yd) - self.assertEqual(data_source.get_bounds(), + self.assertEqual(self.data_source.get_bounds(), ((min(test_xd),min(test_yd)), (max(test_xd),max(test_yd)))) def test_set_data(self): - data_source = GridDataSource(xdata=array([1,2,3]), - ydata=array([1.5, 0.5, -0.5, -1.5]), - sort_order=('ascending', 'descending')) test_xd = array([0,2,4]) test_yd = array([0,1,2,3,4,5]) test_sort_order = ('none', 'none') - data_source.set_data(xdata=test_xd, ydata=test_yd, + self.data_source.set_data(xdata=test_xd, ydata=test_yd, sort_order=('none', 'none')) - self.assertEqual(data_source.sort_order, test_sort_order) - xd, yd = data_source.get_data() + self.assertEqual(self.data_source.sort_order, test_sort_order) + xd, yd = self.data_source.get_data() assert_array_equal(xd.get_data(), test_xd) assert_array_equal(yd.get_data(), test_yd) - self.assertEqual(data_source.get_bounds(), + self.assertEqual(self.data_source.get_bounds(), ((min(test_xd),min(test_yd)), (max(test_xd),max(test_yd)))) def test_metadata(self): - data_source = GridDataSource(xdata=array([1,2,3]), - ydata=array([1.5, 0.5, -0.5, -1.5]), - sort_order=('ascending', 'descending')) - - self.assertEqual(data_source.metadata, + self.assertEqual(self.data_source.metadata, {'annotations': [], 'selections': []}) def test_metadata_changed(self): - data_source = GridDataSource(xdata=array([1,2,3]), - ydata=array([1.5, 0.5, -0.5, -1.5]), - sort_order=('ascending', 'descending')) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata = {'new_metadata': True} + with self.assertTraitChanges(self.data_source, 'metadata_changed', count=1): + self.data_source.metadata = {'new_metadata': True} def test_metadata_items_changed(self): - data_source = GridDataSource(xdata=array([1,2,3]), - ydata=array([1.5, 0.5, -0.5, -1.5]), - sort_order=('ascending', 'descending')) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata['new_metadata'] = True - - - -def assert_close_(desired,actual): - diff_allowed = 1e-5 - diff = abs(ravel(actual) - ravel(desired)) - for d in diff: - if not isinf(d): - assert alltrue(d <= diff_allowed) - return - -def assert_ary_(desired, actual): - if (desired == 'auto'): - assert actual == 'auto' - for d in range(len(desired)): - assert desired[d] == actual[d] - return + with self.assertTraitChanges(self.data_source, 'metadata_changed', count=1): + self.data_source.metadata['new_metadata'] = True if __name__ == '__main__': diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py index 7efd624cd..232861ce3 100644 --- a/chaco/tests/image_data_test_case.py +++ b/chaco/tests/image_data_test_case.py @@ -1,5 +1,5 @@ """ -Test of basic dataseries behavior. +Test of ImageData behavior. """ import os @@ -16,7 +16,11 @@ data_dir = resource_filename('chaco.tests', 'data') -class ArrayDataTestCase(UnittestTools, unittest.TestCase): +class ImageDataTestCase(UnittestTools, unittest.TestCase): + + def setUp(self): + self.myarray = arange(15).reshape(5, 3, 1) + self.data_source = ImageData(data=self.myarray) def test_init_defaults(self): data_source = ImageData() @@ -27,28 +31,21 @@ def test_init_defaults(self): #self.assertEqual(data_source.image_dimension, "image") def test_basic_setup(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - assert_array_equal(myarray, data_source.data) - #self.assertEqual(data_source.value_dimension, "scalar") - self.assertFalse(data_source.is_masked()) + assert_array_equal(self.myarray, self.data_source.data) + #self.assertEqual(self.data_source.value_dimension, "scalar") + self.assertFalse(self.data_source.is_masked()) def test_set_data(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) new_array = arange(0, 30, 2).reshape(5, 3, 1) - with self.assertTraitChanges(data_source, 'data_changed', count=1): - data_source.set_data(new_array) + with self.assertTraitChanges(self.data_source, 'data_changed', count=1): + self.data_source.set_data(new_array) - assert_array_equal(new_array, data_source.data) - self.assertEqual(data_source.get_bounds(), (0, 28)) + assert_array_equal(new_array, self.data_source.data) + self.assertEqual(self.data_source.get_bounds(), (0, 28)) def test_get_data(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - assert_array_equal(myarray, data_source.get_data()) + assert_array_equal(self.myarray, self.data_source.get_data()) def test_get_data_no_data(self): data_source = ImageData() @@ -59,15 +56,12 @@ def test_get_data_transposed(self): myarray = arange(15).reshape(5, 3, 1) data_source = ImageData(data=myarray, transposed=True) - assert_array_equal(swapaxes(myarray, 0, 1), data_source.get_data()) + assert_array_equal(swapaxes(myarray, 0, 1), self.data_source.get_data()) def test_get_data_mask(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - # XXX this is probably not the right thing with self.assertRaises(NotImplementedError): - data, mask = data_source.get_data_mask() + data, mask = self.data_source.get_data_mask() def test_get_data_mask_no_data(self): data_source = ImageData() @@ -77,9 +71,7 @@ def test_get_data_mask_no_data(self): data, mask = data_source.get_data_mask() def test_bounds(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - bounds = data_source.get_bounds() + bounds = self.data_source.get_bounds() self.assertEqual(bounds, (0, 14)) @unittest.skip('test_bounds_empty() fails in this case') @@ -89,19 +81,14 @@ def test_bounds_empty(self): self.assertEqual(bounds, (0, 0)) def test_data_size(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - self.assertEqual(15, data_source.get_size()) + self.assertEqual(15, self.data_source.get_size()) def test_data_size_no_data(self): data_source = ImageData() self.assertEqual(0, data_source.get_size()) def test_get_width(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - self.assertEqual(3, data_source.get_width()) + self.assertEqual(3, self.data_source.get_width()) def test_get_width_transposed(self): myarray = arange(15).reshape(5, 3) @@ -110,10 +97,7 @@ def test_get_width_transposed(self): self.assertEqual(5, data_source.get_width()) def test_get_height(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - self.assertEqual(5, data_source.get_height()) + self.assertEqual(5, self.data_source.get_height()) def test_get_height_transposed(self): myarray = arange(15).reshape(5, 3, 1) @@ -122,10 +106,7 @@ def test_get_height_transposed(self): self.assertEqual(3, data_source.get_height()) def test_array_bounds(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - self.assertEqual(((0, 3), (0, 5)), data_source.get_array_bounds()) + self.assertEqual(((0, 3), (0, 5)), self.data_source.get_array_bounds()) def test_array_bounds_transposed(self): myarray = arange(15).reshape(5, 3, 1) @@ -148,22 +129,13 @@ def test_fromfile_png_rgba(self): self.assertEqual(data_source.value_depth, 4) def test_metadata(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - self.assertEqual(data_source.metadata, + self.assertEqual(self.data_source.metadata, {'annotations': [], 'selections': []}) def test_metadata_changed(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata = {'new_metadata': True} + with self.assertTraitChanges(self.data_source, 'metadata_changed', count=1): + self.data_source.metadata = {'new_metadata': True} def test_metadata_items_changed(self): - myarray = arange(15).reshape(5, 3, 1) - data_source = ImageData(data=myarray) - - with self.assertTraitChanges(data_source, 'metadata_changed', count=1): - data_source.metadata['new_metadata'] = True + with self.assertTraitChanges(self.data_source, 'metadata_changed', count=1): + self.data_source.metadata['new_metadata'] = True diff --git a/chaco/tests/multi_array_data_source_test_case.py b/chaco/tests/multi_array_data_source_test_case.py index 566b8229b..28e62ca90 100644 --- a/chaco/tests/multi_array_data_source_test_case.py +++ b/chaco/tests/multi_array_data_source_test_case.py @@ -1,5 +1,5 @@ """ -Test of basic dataseries behavior. +Test of MultiArrayDataSource behavior. """ import unittest2 as unittest From 6ba47ecdbbb9e97e9458e2b5d4e56c1410417049 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Tue, 23 Dec 2014 14:36:26 +0000 Subject: [PATCH 25/30] Fix typo. --- chaco/tests/image_data_test_case.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chaco/tests/image_data_test_case.py b/chaco/tests/image_data_test_case.py index 232861ce3..3177ebcea 100644 --- a/chaco/tests/image_data_test_case.py +++ b/chaco/tests/image_data_test_case.py @@ -56,7 +56,7 @@ def test_get_data_transposed(self): myarray = arange(15).reshape(5, 3, 1) data_source = ImageData(data=myarray, transposed=True) - assert_array_equal(swapaxes(myarray, 0, 1), self.data_source.get_data()) + assert_array_equal(swapaxes(myarray, 0, 1), data_source.get_data()) def test_get_data_mask(self): # XXX this is probably not the right thing From 717aea8ac4eb83d03e1a6f35bb735c9d93cf2d7a Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Mon, 29 Dec 2014 18:52:31 +0000 Subject: [PATCH 26/30] Add a BaseArrayDataSource and tests, plus chnges to ABC. --- chaco/abstract_data_source.py | 184 ++++--- chaco/base.py | 29 +- chaco/base_array_data_source.py | 487 ++++++++++++++++++ .../tests/base_array_data_source_test_case.py | 423 +++++++++++++++ 4 files changed, 1037 insertions(+), 86 deletions(-) create mode 100644 chaco/base_array_data_source.py create mode 100644 chaco/tests/base_array_data_source_test_case.py diff --git a/chaco/abstract_data_source.py b/chaco/abstract_data_source.py index f239211ca..24594b831 100644 --- a/chaco/abstract_data_source.py +++ b/chaco/abstract_data_source.py @@ -1,136 +1,170 @@ """ Defines the AbstractDataSource class. + +This is the abstract base class for all sources which provide data to Chaco +plots and renderers. + + """ from __future__ import absolute_import, division, print_function, unicode_literals -from traits.api import Bool, Dict, Event, HasTraits +from traits.api import ABCHasTraits, Dict, Event, Int, Str # Local relative imports -from .base import DimensionTrait +from .base import ValueType -class AbstractDataSource(HasTraits): - """ This abstract interface must be implemented by any class supplying data - to Chaco. - Chaco does not have a notion of a "data format". For the most part, a data - source looks like an array of values with an optional mask and metadata. - If you implement this interface, you are responsible for adapting your - domain-specific or application-specific data to meet this interface. +class AbstractDataSource(ABCHasTraits): + """ Abstract interface for data sources used by Chaco renderers + + This abstract interface must be implemented by any class supplying data + to Chaco renderers. Chaco does not have a notion of a "data format". + For the most part, a data source looks like an array of values with an + optional mask and metadata. If you implement this interface, you are + responsible for adapting your domain-specific or application-specific data + to meet this interface. Chaco provides some basic data source implementations. In most cases, the easiest strategy is to create one of these basic data source with the numeric data from a domain model. In cases when this strategy is not possible, domain classes (or an adapter) must implement AbstractDataSource. - """ - # The dimensionality of the value at each index point. - # Subclasses re-declare this trait as a read-only trait with - # the right default value. - value_dimension = DimensionTrait + Notes + ----- - # The dimensionality of the indices into this data source. - # Subclasses re-declare this trait as a read-only trait with - # the right default value. - index_dimension = DimensionTrait + The contract implied by the AbstractDataSource interface is that data + arrays provided by the get methods of the class should not be treated as + read-only arrays, and that any change to the data or mask (such as by + subclasses which provide a `set_data` method) will be accompanied by the + `data_changed` event being fired. - # A dictionary keyed on strings. In general, it maps to indices (or tuples - # of indices, depending on **value_dimension**), as in the case of - # selections and annotations. Applications and renderers can add their own - # custom metadata, but must avoid using keys that might result in name - # collision. - metadata = Dict + """ - # Event that fires when the data values change. + #: The dimension of the values provided by the data source. + #: Implementations of the interface will typically redefine this as a + #: read-only trait with a particular value. + value_type = ValueType + + #: The dimension of the indices into the data source. + #: Implementations of the interface will typically redefine this as a + #: read-only trait with a particular value. + dimension = Int + + #: The metadata for the data source. + #: Metadata values are typically used for annotations and selections + #: on the data source, and so each keyword corresponds to a collection of + #: indices into the data source. Applications and renderers can add their + #: own custom metadata, but must avoid using keys that might result in name + #: collision. + metadata = Dict(Str) + + #: Event that fires when the data values change. data_changed = Event - # Event that fires when just the bounds change. + #: Event that fires when the bounds (ie. the extent of the values) change. bounds_changed = Event - # Event that fires when metadata structure is changed. + #: Event that fires when metadata structure is changed. metadata_changed = Event - # Should the data that this datasource refers to be serialized when - # the datasource is serialized? - persist_data = Bool(True) - #------------------------------------------------------------------------ - # Abstract methods + # AbstractDataSource interface #------------------------------------------------------------------------ def get_data(self): - """get_data() -> data_array + """Get an array representing the data stored in the data source. + + Returns + ------- - Returns a data array of the dimensions of the data source. This data - array must not be altered in-place, and the caller must assume it is - read-only. This data is contiguous and not masked. + data_array : array + An array of the dimensions specified by the index and value + dimension traits. This data array must not be altered in-place, + and the caller must assume it is read-only. This data is + contiguous and not masked. - In the case of structured (gridded) 2-D data, this method may return - two 1-D ArrayDataSources as an optimization. """ raise NotImplementedError def get_data_mask(self): - """get_data_mask() -> (data_array, mask_array) + """Get arrays representing the data and the mask of the data source. + + Returns + ------- - Returns the full, raw, source data array and a corresponding binary - mask array. Treat both arrays as read-only. + data_array, mask: array of values, array of bool + Returns the full, raw, source data array and a corresponding binary + mask array. Treat both arrays as read-only. + + The mask is a superposition of the masks of all upstream data sources. + The length of the returned array may be much larger than what + get_size() returns; the unmasked portion, however, matches what + get_size() returns. - The mask is a superposition of the masks of all upstream data sources. - The length of the returned array may be much larger than what - get_size() returns; the unmasked portion, however, matches what - get_size() returns. """ raise NotImplementedError def is_masked(self): - """is_masked() -> bool + """Whether or not the data is masked. + + Returns + ------- + + is_masked : bool + True if this data source's data uses a mask. In this case, + to retrieve the data, call get_data_mask() instead of get_data(). - Returns True if this data source's data uses a mask. In this case, - to retrieve the data, call get_data_mask() instead of get_data(). - If you call get_data() for this data source, it returns data, but that - data might not be the expected data. """ raise NotImplementedError def get_size(self): - """get_size() -> int + """The size of the data. - Returns an integer estimate or the exact size of the dataset that - get_data() returns for this object. This method is useful for - down-sampling. - """ - raise NotImplementedError + This method is useful for down-sampling. - def get_bounds(self): - """get_bounds() -> tuple(min, max) + Returns + ------- - Returns a tuple (min, max) of the bounding values for the data source. - In the case of 2-D data, min and max are 2-D points that represent the - bounding corners of a rectangle enclosing the data set. Note that - these values are not view-dependent, but represent intrinsic properties - of the data source. + size : int or tuple of ints + An estimate (or the exact size) of the dataset that get_data() + returns for this object. For data sets with n-dimensional index + values, this can return an n-tuple indicating the size in each + dimension. - If data is the empty set, then the min and max vals are 0.0. """ raise NotImplementedError + def get_bounds(self): + """Get the minimum and maximum finite values of the data. - ### Persistence ########################################################### + Returns + ------- - def _metadata_default(self): - return {"selections":[], "annotations":[]} + bounds : tuple of min, max + A tuple (min, max) of the bounding values for the data source. + In the case of n-dimensional data values, min and max are + n-dimensional points that represent the bounding corners of a + rectangle enclosing the data set. Note that these values are not + view-dependent, but represent intrinsic properties of the data + source. - def __getstate__(self): - state = super(AbstractDataSource,self).__getstate__() + Raises + ------ - # everything but 'metadata' - for key in ['value_dimension', 'index_dimension', 'persist_data']: - if state.has_key(key): - del state[key] + TypeError: + If data's value type is not amenable to sorting, a TypeError can + be raised. - return state + ValueError: + If data is empty, all NaN, or otherwise has no sensible ordering, + then this should raise a ValueError. + + """ + raise NotImplementedError + ### Trait defaults ####################################################### -# EOF + def _metadata_default(self): + return {"selections":[], "annotations":[]} diff --git a/chaco/base.py b/chaco/base.py index db7edb4e0..2b8f10c60 100644 --- a/chaco/base.py +++ b/chaco/base.py @@ -9,33 +9,40 @@ # Major library imports from numpy import (array, argsort, concatenate, cos, dot, empty, nonzero, - pi, sin, take, ndarray) + pi, sin, take, ndarray, number) # Enthought library imports -from traits.api import CArray, Enum, Trait +from traits.api import ArrayOrNone, Enum +# Exceptions + +class DataUpdateError(RuntimeError): + pass + +class DataInvalidError(ValueError): + pass # Dimensions # A single array of numbers. -NumericalSequenceTrait = Trait(None, None, CArray(value=empty(0))) +NumericalSequenceTrait = ArrayOrNone(shape=(None,), value=empty(0)) # A sequence of pairs of numbers, i.e., an Nx2 array. -PointTrait = Trait(None, None, CArray(value=empty(0))) +PointTrait = ArrayOrNone(shape=(None, 2), value=empty(shape=(0, 2))) # An NxM array of numbers. -ImageTrait = Trait(None, None, CArray(value=empty(0))) +ImageTrait = ArrayOrNone(shape=(None, None), value=empty(shape=(0, 0))) # An 3D array of numbers of shape (Nx, Ny, Nz) -CubeTrait = Trait(None, None, CArray(value=empty(0))) - +CubeTrait = ArrayOrNone(shape=(None, None, None), value=empty(shape=(0, 0, 0))) -# This enumeration lists the fundamental mathematical coordinate types that -# Chaco supports. -DimensionTrait = Enum("scalar", "point", "image", "cube") +#: The fundamental value types that data sources can take. These can be +#: agumented by adding to `ValueType.values`. +ValueType = Enum("scalar", "point", "color", "index", "mask", "text", + "datetime") -# Linear sort order. +#: Linear sort order. SortOrderTrait = Enum("ascending", "descending", "none") diff --git a/chaco/base_array_data_source.py b/chaco/base_array_data_source.py new file mode 100644 index 000000000..cd49a3c98 --- /dev/null +++ b/chaco/base_array_data_source.py @@ -0,0 +1,487 @@ +""" +Defines the BaseArrayDataSource class. + +This is a base class that implements common logic for NumPy array-based data +sources, ie. where the underlying data is stored in a numpy array fairly +directly. + +""" + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +from contextlib import contextmanager +from numpy import isfinite, isnan, nanmax, nanmin, ones + +from traits.api import ArrayOrNone, Bool, Either, Instance, Tuple + +from .abstract_data_source import AbstractDataSource +from .base import DataInvalidError, DataUpdateError + + +class BaseArrayDataSource(AbstractDataSource): + """ Base class for data sources which store data in a NumPy array + + This class provides basic implementation of the AbstractDataSource + interface on top of a numpy array. The class also guards against + accessing the data while a change to the data is under way. + + Parameters + ---------- + + data : array-like or None (default is None) + If None, the current data future queries to get_data will return an + appropriate empty data object. Otherwise the data must be an + array-like compatible with the dimension and value type. + + mask : array-like of bool or None (default is None) + If None, this clears the mask. Otherwise the mask must be an + array-like compatible with the dimension of the data source and the + shape of the underlying data array. + + Notes + ----- + + This class is abstract and shouldn't be instantiated directly. This class + also should not be used as an interface: plots and renderers shouldn't care + about the mechanics of data source internals (ie. whether the data is in + an array) but on the dimensionality, value type, masking, etc. + + Subclasses must provide valid `dimension` and `value_type` traits, and + implement the private `_empty_data()` method to return an appropriate + value when the data is set to `None` (usually an empty array of the + correct dimensionality, but with zeroed shape). + + The constructor does not check that the data array and mask array have the + correct dimension, value type or compatible shapes. Subclasses should use + appropriate trait types to ensure that the underlying arrays have + appropriate dimension and value type. + + The implementation is designed for arrays that fit comfortably in memory + and, in particular, for which is is not an imposition to create temporary + copies and masks, since the `nanmin` and `nanmax` functions do that under + the covers. + + """ + + #------------------------------------------------------------------------ + # BaseArrayDataSource interface + #------------------------------------------------------------------------ + + def __init__(self, data=None, mask=None, **traits): + super(BaseArrayDataSource, self).__init__(**traits) + self.set_data(data, mask) + + def set_data(self, data, mask=None): + """ Set the value of the data and (optional) mask. + + This method should be used if an atomic update of data and mask is + required. + + Parameters + ---------- + + data : array-like or None + If None, the current data is removed and future queries to + get_data will return an appropriate empty data object. Otherwise + the data must be an array-like compatible with the dimension and + value type. + + mask : array-like of bool or None + If None, leaves the mask as-is. Otherwise the mask must be an + array-like compatible with the dimension of the data source and the + shape of the underlying data array. + + Raises + ------ + + DataUpdateError: + If the data is already being modified when `set_data` is called, + a `DataUpdateError` will be raised. + + Notes + ----- + + This method does not check that the data array and mask array have the + correct dimension, value type or compatible shapes. If the shapes are + incompatible, calls to `get_data_mask` will raise a `ValueError`. + + """ + with self._updating_data(): + self._data = data + self._finite_mask = None + self._cached_bounds = None + if mask is not None: + self._mask = mask + + def set_mask(self, mask): + """ Set the value of the mask + + The actual mask will be the intesection of this mask and the + finite data values. + + Parameters + ---------- + + mask : array-like of bool or None + If None, this clears the mask. Otherwise the mask must be an + array-like compatible with the dimension of the data source and + the shape of the underlying data array. + + Raises + ------ + + DataUpdateError: + If the data is already being modified when `set_mask` is called, + a `DataUpdateError` will be raised. + + Notes + ----- + + This method does not check that the data array and mask array have + compatible shapes. If the shapes are incompatible, calls to + `get_data_mask` will raise a `ValueError`. + + """ + with self._updating_data(): + self._mask = mask + + def remove_mask(self): + """ Remove the mask + + This is largely for backwards compatibility, it is equivalent to + set_mask(None). + + Raises + ------ + + DataUpdateError: + If the data is already being modified when `remove_mask` is called, + a `DataUpdateError` will be raised. + + """ + self.set_mask(None) + + def invalidate_data(self): + """ Mark the data value as invalid. + + Data can only become valid again after a successful `set_data`. + + """ + self._data_valid = False + + def access_guard(self): + """ Context manager that detects data becoming invalid during access + + This listens for changes in internal state to detect if the + data is initially invalid, or the `invalidate_data` method has been + called during the context. + + Returns + ------- + + access_guard : context manager + A context manager that raises a DataInvalidError if the data is + invalidated. + + Raises + ------ + + DataInvalidError: + When the data is invalid. + + """ + return Guard(self) + + #------------------------------------------------------------------------ + # AbstractDataSource interface + #------------------------------------------------------------------------ + + def get_data(self): + """Get an array representing the data stored in the data source. + + Returns + ------- + + data_array : array of values + An array of the dimensions specified by the index and value + dimension traits. This data array must not be altered in-place, + and the caller must assume it is read-only. This data is + contiguous and not masked. + + """ + with self.access_guard(): + data = self._get_data_unsafe() + return data + + def get_data_mask(self): + """Get arrays representing the data and the mask of the data source. + + Returns + ------- + + data_array, mask: array of values, array of bool + Returns the full source data array and a corresponding binary + mask array. Treat both arrays as read-only. + + Raises + ------ + + ValueError: + If mask's shape is incompatible with the data shape. + + """ + with self.access_guard(): + data, mask = self._get_data_mask_unsafe() + + return data, mask + + def is_masked(self): + """Whether or not the data is masked. + + Returns + ------- + + is_masked : bool + True if this data source's data uses a mask of has non-finite + values. + + Raises + ------ + + ValueError: + If mask's shape is incompatible with the data shape. + + """ + with self.access_guard(): + self._compute_finite() + masked = self._mask is not None or not self._finite + return masked + + def get_size(self): + """The size of the data. + + This method is useful for down-sampling. + + Returns + ------- + + size : tuple of ints + Returns the shape of the data for the index dimensions. + + """ + with self.access_guard(): + data = self._get_data_unsafe() + + return data.shape[:self.dimension] + + def get_bounds(self): + """ Get the minimum and maximum finite values of the data. + + Returns + ------- + + bounds : tuple of min, max + A tuple (min, max) of the bounding values for the data source. + In the case of n-dimensional data values, min and max are + n-dimensional points that represent the bounding corners of a + rectangle enclosing the data set. + + Raises + ------ + + ValueError: + If an all-nan axis is found. + + TypeError: + If the dtype is not appropriate for min and max (eg. strings) + + """ + with self.access_guard(): + if self._cached_bounds is None: + self._compute_bounds() + bounds = self._cached_bounds + + return bounds + + #------------------------------------------------------------------------ + # Event handlers + #------------------------------------------------------------------------ + + def _metadata_changed(self, event): + self.metadata_changed = True + + def _metadata_items_changed(self, event): + self.metadata_changed = True + + #------------------------------------------------------------------------ + # Private interface + #------------------------------------------------------------------------ + + #: the array holding the data. This should be overridden with something + #: more specific in concrete subclasses. + _data = ArrayOrNone + + #: the user-supplied mask. + _mask = ArrayOrNone(dtype=bool) + + #: locations where the data is finite + _finite_mask = ArrayOrNone(dtype=bool) + + #: the actual mask taking into account non-finite values. + _cached_mask = ArrayOrNone(dtype=bool) + + #: whether or not the data is finite at all locations + _finite = Bool + + #: the min, max bounds of the (unmasked) data + _cached_bounds = Either(Tuple, None) + + #: a flag indicating whether the data is currently valid + _data_valid = Bool(False) + + #: a lock that should be acquired before modifying the data. This should + #: usually be acquired in a non-blocking fashion + _update_lock = Instance('threading.Lock', args=()) + + def _get_data_unsafe(self): + """ Return the data without aquiring the update lock + + This method can be safely called by subclasses via super() without + worrying about locking. + + """ + data = self._data + if data is None: + return self._empty_data() + return data + + def _get_data_mask_unsafe(self): + """ Return the data without aquiring the update lock + + This method can be safely called by subclasses via super() without + worrying about locking. + + """ + data = self._get_data_unsafe() + self._compute_mask() + mask = self._cached_mask + return data, mask + + def _empty_data(self): + """ Method that returns an empty array of the appropriate type + + Concrete subclasses must implement this. + + """ + raise NotImplementedError + + def _compute_mask(self): + """ Compute the mask and cache it """ + if self._cached_mask is None: + self._compute_finite() + if self._mask is None: + self._cached_mask = self._finite_mask + elif self._finite: + self._cached_mask = self._mask + else: + self._cached_mask = self._finite_mask & self._mask + + def _compute_finite(self): + """ Compute locations where data is finite + + Subclasses with complex dtypes may need to override this method. + + """ + if self._finite_mask is None: + data = self._get_data_unsafe() + non_index_axes = tuple(range(self.dimension, len(data.shape))) + try: + self._finite_mask = isfinite(data).all(axis=non_index_axes) + self._finite = self._finite_mask.all() + except TypeError: + # dtype for which isfinite doesn't work; finite by definition + self._finite_mask = ones(shape=data.shape[:self.dimension], + dtype=bool) + self._finite = True + + def _compute_bounds(self): + """ Compute bounds of values + + Subclasses may override this to avoid un-needed computation in cases + where minimum and maximum values are known (eg. sorted data). + + Raises + ------ + + ValueError: + If an all-nan axis is found. + + TypeError: + If the dtype is not appropriate for min and max (eg. strings) + + """ + data = self._get_data_unsafe() + + index_axes = tuple(range(self.dimension)) + min_value = nanmin(data, axis=index_axes) + max_value = nanmax(data, axis=index_axes) + + # only need to check min value: nans in max_value imply nans here too + if isnan(min_value).any(): + raise ValueError("All-NaN axis encountered") + + self._cached_bounds = (min_value, max_value) + + + @contextmanager + def _updating_data(self): + """ Context manager for updating data + + In the enter method this manager attempts to acquire the update lock, + raising an exception if it is not immediately available, then sets the + validity of the data to False. + + In the exit method it ensures that the update lock is released, and + sets the + """ + acquired = self._update_lock.acquire(False) + if not acquired: + msg = "data update conflict, {} cannot acquire lock" + raise DataUpdateError(msg.format(self)) + try: + self.invalidate_data() + yield self + self._cached_mask = None + self._data_valid = True + finally: + self._update_lock.release() + + self.data_changed = True + + +class Guard(object): + """ Guard against trait becoming invalid during access + + Lightweight context manager that raises an exception if trait becomes + invalid during an operation. + + """ + + def __init__(self, obj, valid_trait='_data_valid'): + self.valid_trait = valid_trait + self.valid = True + self.obj = obj + + def valid_change_listener(self, new): + self.valid = self.valid and new + + def __enter__(self): + self.obj.on_trait_change(self.valid_change_listener, self.valid_trait) + self.valid = self.valid and getattr(self.obj, self.valid_trait) + return self + + def __exit__(self, exc_type, exc_value, tb): + if not self.valid: + msg = "data source {} data is not valid" + raise DataInvalidError(msg.format(self)) diff --git a/chaco/tests/base_array_data_source_test_case.py b/chaco/tests/base_array_data_source_test_case.py new file mode 100644 index 000000000..03258374a --- /dev/null +++ b/chaco/tests/base_array_data_source_test_case.py @@ -0,0 +1,423 @@ +""" +Test cases for the BaseArrayDataSource class. +""" + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +import unittest2 as unittest +import mock +from numpy import arange, array, empty, inf, isfinite, issubdtype, NaN, ones +from numpy.testing import assert_array_equal + +from traits.api import ReadOnly, TraitError +from traits.testing.unittest_tools import UnittestTools + +from chaco.base import DataInvalidError, DataUpdateError +from chaco.base_array_data_source import BaseArrayDataSource + + +class TestArrayDataSource(BaseArrayDataSource): + """ Very simple implementation to test basic functionality """ + + #: a 1-D array + dimension = ReadOnly(1) + + #: with scalar values + value_type = ReadOnly('scalar') + + +class TestVectorArrayDataSource(BaseArrayDataSource): + """ Very simple implementation to test basic functionality """ + + #: a 1-D array + dimension = ReadOnly(1) + + #: with scalar values + value_type = ReadOnly('vector') + + +class BaseArrayDataSourceTestCase(unittest.TestCase, UnittestTools): + """ Test cases for the BaseArrayDataSource class. """ + + def setUp(self): + self.data = arange(12.0) + self.data[5] = NaN + self.data_source = TestArrayDataSource(data=self.data) + self.data_source._empty_data = mock.MagicMock( + return_value=empty((0,))) + + self.mask = (self.data % 2 == 0) + self.masked_data_source = TestArrayDataSource(data=self.data, + mask=self.mask) + self.masked_data_source._empty_data = mock.MagicMock( + return_value=empty((0,))) + + self.empty_data_source = TestArrayDataSource() + self.empty_data_source._empty_data = mock.MagicMock( + return_value=empty((0,))) + + # something like an array of colors + self.vector_data = arange(12.0).reshape(4, 3) + self.vector_data_source = TestVectorArrayDataSource(self.vector_data) + self.vector_data_source._empty_data = mock.MagicMock( + return_value=empty((0, 3))) + + # an array of ints + self.int_data = arange(12) + self.int_data_source = TestArrayDataSource(data=self.int_data, + mask=self.mask) + self.int_data_source._empty_data = mock.MagicMock( + return_value=empty((0,), dtype=int)) + + # an array of strings + self.text_data = array(['zero', 'one', 'two', 'three', 'four', 'five', + 'six', 'seven', 'eight', 'nine', 'ten', + 'eleven']) + self.text_data_source = TestArrayDataSource(data=self.text_data, + mask=self.mask) + self.text_data_source._empty_data = mock.MagicMock( + return_value=empty((0,), dtype='S1')) + + + def test_initialize(self): + self.validate_data(self.data_source, self.data) + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + + def test_initialize_masked(self): + self.validate_data(self.masked_data_source, self.data, self.mask) + self.assertEquals(self.masked_data_source.get_bounds(), (0.0, 11.0)) + + def test_initialize_empty(self): + data_source = self.empty_data_source + self.validate_data(data_source, None) + + self.assertEquals(data_source.get_size(), (0,)) + self.assertTrue(data_source._empty_data.called) + + with self.assertRaises(ValueError): + data_source.get_bounds() + + def test_initialize_vector(self): + data_source = self.vector_data_source + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds(), + array([[0.0, 1.0, 2.0], [9.0, 10.0, 11.0]])) + + def test_initialize_int(self): + data_source = self.int_data_source + self.validate_data(data_source, self.int_data, self.mask) + assert_array_equal(data_source.get_bounds(), (0, 11)) + + def test_initialize_text(self): + data_source = self.text_data_source + self.validate_data(data_source, self.text_data, self.mask) + with self.assertRaises(TypeError): + data_source.get_bounds() + + def test_vector_nan(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = NaN + data_source.set_data(self.vector_data) + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds(), + array([[0.0, 1.0, 2.0], [9.0, 10.0, 11.0]])) + + def test_set_data(self): + data_source = self.data_source + new_data = arange(15.0) + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(new_data) + + self.validate_data(data_source, new_data) + self.assertEquals(data_source.get_bounds(), (0.0, 14.0)) + + def test_set_data_masked(self): + data_source = self.data_source + new_data = arange(15.0) + new_mask = (new_data % 2 == 0) + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(new_data, new_mask) + + self.validate_data(data_source, new_data, new_mask) + self.assertEquals(data_source.get_bounds(), (0.0, 14.0)) + + def test_set_data_none(self): + data_source = self.data_source + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(None) + + self.validate_data(data_source, None) + + with self.assertRaises(ValueError): + self.empty_data_source.get_bounds() + + def test_set_data_all_nan(self): + data_source = self.data_source + new_data = empty(shape=(12,)) + new_data[:] = NaN + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(new_data) + + self.validate_data(data_source, new_data) + + with self.assertRaises(ValueError): + self.empty_data_source.get_bounds() + + def test_set_data_invalid(self): + data_source = self.data_source + + with self.assertRaises(TraitError): + with self.assertTraitDoesNotChange(data_source, 'data_changed'): + data_source.set_data('invalid data') + + self.check_invalid(data_source) + + # now check that we can reset to valid state + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(self.data) + + self.validate_data(data_source, self.data) + + def test_set_data_update_lock_fail(self): + data_source = self.data_source + new_data = arange(15.0) + + with self.assertTraitDoesNotChange(data_source, 'data_changed'): + with self.assertRaises(DataUpdateError): + with data_source._update_lock: + data_source.set_data(new_data) + + # data should be unmodified by failed update + self.validate_data(data_source, self.data) + + def test_masked_data_set_data(self): + data_source = self.data_source + new_data = arange(0, 24.0, 2) + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(new_data) + + self.validate_data(data_source, new_data) + self.assertEquals(data_source.get_bounds(), (0.0, 22.0)) + + def test_set_mask(self): + data_source = self.data_source + new_mask = (self.data % 3 == 0) + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_mask(new_mask) + + self.validate_data(data_source, self.data, new_mask) + self.assertEquals(data_source.get_bounds(), (0.0, 11.0)) + + def test_set_mask_update_lock_fail(self): + data_source = self.data_source + new_mask = (self.data % 3 == 0) + + with self.assertTraitDoesNotChange(data_source, 'data_changed'): + with self.assertRaises(DataUpdateError): + with data_source._update_lock: + data_source.set_mask(new_mask) + + def test_remove_mask(self): + data_source = self.masked_data_source + + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.remove_mask() + + self.validate_data(data_source, self.data) + self.assertEquals(data_source.get_bounds(), (0.0, 11.0)) + + def test_invalidate_data(self): + data_source = self.data_source + + with self.assertTraitDoesNotChange(data_source, 'data_changed'): + data_source.invalidate_data() + + self.check_invalid(data_source) + + # now check that we can reset to valid state + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(self.data) + + self.validate_data(data_source, self.data) + + def test_access_guard(self): + data_source = self.data_source + + with self.assertRaises(DataInvalidError): + with data_source.access_guard(): + data_source.invalidate_data() + + self.check_invalid(data_source) + + # now check that we can reset to valid state + with self.assertTraitChanges(data_source, 'data_changed', 1): + data_source.set_data(self.data) + + self.validate_data(data_source, self.data) + + def test_data_mask_incompatible(self): + data_source = TestArrayDataSource(self.data, (arange(15.0) % 3 == 0)) + + with self.assertRaises(ValueError): + data_source.get_data_mask() + + def test_get_bounds_plus_infinity(self): + self.data[3] = inf + data_source = TestArrayDataSource(self.data) + + self.assertEquals(data_source.get_bounds(), (0, inf)) + + def test_get_bounds_all_plus_infinity(self): + self.data[:] = inf + data_source = TestArrayDataSource(self.data) + + self.assertEquals(data_source.get_bounds(), (inf, inf)) + + def test_get_bounds_minus_infinity(self): + self.data[3] = -inf + data_source = TestArrayDataSource(self.data) + + self.assertEquals(data_source.get_bounds(), (-inf, 11)) + + def test_get_bounds_all_minus_infinity(self): + self.data[:] = -inf + data_source = TestArrayDataSource(self.data) + + self.assertEquals(data_source.get_bounds(), (-inf, -inf)) + + def test_get_bounds_plus_minus_infinity(self): + self.data[3] = inf + self.data[7] = -inf + data_source = TestArrayDataSource(self.data) + + self.assertEquals(data_source.get_bounds(), (-inf, inf)) + + def test_get_bounds_vector_plus_infinity(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = inf + data_source.set_data(self.vector_data) + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds()[0], + array([0.0, 1.0, 2.0])) + assert_array_equal(data_source.get_bounds()[1], + array([9.0, 10.0, inf])) + + def test_get_bounds_vector_minus_infinity(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = -inf + data_source.set_data(self.vector_data) + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds()[0], + array([0.0, 1.0, -inf])) + assert_array_equal(data_source.get_bounds()[1], + array([9.0, 10.0, 11.0])) + + def test_get_bounds_vector_plus_minus_infinity(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = -inf + self.vector_data[2, 1] = inf + data_source.set_data(self.vector_data) + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds()[0], + array([0.0, 1.0, -inf])) + assert_array_equal(data_source.get_bounds()[1], + array([9.0, inf, 11.0])) + + def test_get_bounds_vector_nan_vector(self): + data_source = self.vector_data_source + self.vector_data[:, 2] = NaN + data_source.set_data(self.vector_data) + + self.validate_data(data_source, self.vector_data) + with self.assertRaises(ValueError): + data_source.get_bounds() + + def test_metadata_changed(self): + with self.assertTraitChanges(self.data_source, 'metadata_changed', + count=1): + self.data_source.metadata = {'new_metadata': True} + + def test_metadata_items_changed(self): + with self.assertTraitChanges(self.data_source, 'metadata_changed', + count=1): + self.data_source.metadata['new_metadata'] = True + + def test_empty_data_not_implemented(self): + data_source = BaseArrayDataSource() + + with self.assertRaises(NotImplementedError): + data_source.get_data() + + #### Common validation methods ########################################### + + def validate_data(self, data_source, expected_data, expected_mask=None): + if expected_data is None: + expected_empty = True + else: + expected_empty = False + expected_shape = expected_data.shape[:data_source.dimension] + + if expected_mask is None: + if not expected_empty: + axes = tuple(range(1, len(expected_data.shape))) + try: + expected_mask = isfinite(expected_data).all(axis=axes) + except TypeError: + expected_mask = ones(expected_data.shape, dtype=bool) + expected_is_masked = not expected_mask.all() + else: + expected_is_masked = False + else: + if not expected_empty: + axes = tuple(range(1, len(expected_data.shape))) + try: + expected_mask &= isfinite(expected_data).all(axis=axes) + except TypeError: + pass + expected_is_masked = True + + # check that get_data() works + data = data_source.get_data() + + self.assertEqual(data_source._empty_data.called, expected_empty) + if not expected_empty: + assert_array_equal(data, expected_data) + + # check that get_data_mask() works + data, mask = data_source.get_data_mask() + + self.assertEqual(data_source._empty_data.called, expected_empty) + if not expected_empty: + assert_array_equal(data, expected_data) + assert_array_equal(mask, expected_mask) + + # check that is_masked() works + self.assertEqual(data_source.is_masked(), expected_is_masked) + + # check that get_size() works + if not expected_empty: + self.assertEquals(data_source.get_size(), expected_shape) + + def check_invalid(self, data_source): + # check that methods to get data fail + with self.assertRaises(DataInvalidError): + data_source.get_data() + + with self.assertRaises(DataInvalidError): + data_source.get_data_mask() + + with self.assertRaises(DataInvalidError): + data_source.is_masked() + + with self.assertRaises(DataInvalidError): + data_source.get_size() + + with self.assertRaises(DataInvalidError): + data_source.get_bounds() From 1423fb14cc202bef7926789162505775bc2561c2 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Mon, 29 Dec 2014 18:55:43 +0000 Subject: [PATCH 27/30] Remove references to DimensionTrait. --- chaco/api.py | 2 +- chaco/image_data.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/chaco/api.py b/chaco/api.py index e2b0557c8..3e5014a09 100644 --- a/chaco/api.py +++ b/chaco/api.py @@ -5,7 +5,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals -from .base import NumericalSequenceTrait, PointTrait, ImageTrait, DimensionTrait, \ +from .base import NumericalSequenceTrait, PointTrait, ImageTrait, \ SortOrderTrait, bin_search, reverse_map_1d, right_shift, \ left_shift, sort_points, find_runs, arg_find_runs, \ point_line_distance diff --git a/chaco/image_data.py b/chaco/image_data.py index 25e44181d..82fecf6ce 100644 --- a/chaco/image_data.py +++ b/chaco/image_data.py @@ -10,7 +10,7 @@ from traits.api import Bool, Int, Property, ReadOnly, Tuple # Local relative imports -from .base import DimensionTrait, ImageTrait +from .base import ImageTrait from .abstract_data_source import AbstractDataSource class ImageData(AbstractDataSource): @@ -22,7 +22,7 @@ class ImageData(AbstractDataSource): on the context in which the ImageData instance will be used. """ # The dimensionality of the data. - dimension = ReadOnly(DimensionTrait('image')) + dimension = ReadOnly(2) # Depth of the values at each i,j. Values that are used include: # From 23a0df3ed92458e42ef813b12070c7cf7750e3d3 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Mon, 29 Dec 2014 19:13:29 +0000 Subject: [PATCH 28/30] More careful use of format. --- chaco/base_array_data_source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chaco/base_array_data_source.py b/chaco/base_array_data_source.py index cd49a3c98..6e1eaad94 100644 --- a/chaco/base_array_data_source.py +++ b/chaco/base_array_data_source.py @@ -447,7 +447,7 @@ def _updating_data(self): """ acquired = self._update_lock.acquire(False) if not acquired: - msg = "data update conflict, {} cannot acquire lock" + msg = "data update conflict, {0} cannot acquire lock" raise DataUpdateError(msg.format(self)) try: self.invalidate_data() @@ -483,5 +483,5 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, tb): if not self.valid: - msg = "data source {} data is not valid" + msg = "data source {0} data is not valid" raise DataInvalidError(msg.format(self)) From ffc4953b014c3bf65b0fb8ce56f072fd95e227bc Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Mon, 5 Jan 2015 11:03:37 +0000 Subject: [PATCH 29/30] Refactor BaseArrayDataSource into BaseDataSource. --- chaco/abstract_data_source.py | 4 +- chaco/base.py | 43 +- chaco/base_array_data_source.py | 330 ++------------ chaco/base_data_source.py | 410 ++++++++++++++++++ .../tests/base_array_data_source_test_case.py | 4 +- chaco/tests/base_data_source_test_case.py | 358 +++++++++++++++ 6 files changed, 821 insertions(+), 328 deletions(-) create mode 100644 chaco/base_data_source.py create mode 100644 chaco/tests/base_data_source_test_case.py diff --git a/chaco/abstract_data_source.py b/chaco/abstract_data_source.py index 24594b831..58b2dbc4d 100644 --- a/chaco/abstract_data_source.py +++ b/chaco/abstract_data_source.py @@ -44,12 +44,12 @@ class AbstractDataSource(ABCHasTraits): #: The dimension of the values provided by the data source. #: Implementations of the interface will typically redefine this as a #: read-only trait with a particular value. - value_type = ValueType + value_type = ValueType('scalar') #: The dimension of the indices into the data source. #: Implementations of the interface will typically redefine this as a #: read-only trait with a particular value. - dimension = Int + dimension = Int(1) #: The metadata for the data source. #: Metadata values are typically used for annotations and selections diff --git a/chaco/base.py b/chaco/base.py index 2b8f10c60..bbd108c04 100644 --- a/chaco/base.py +++ b/chaco/base.py @@ -8,11 +8,11 @@ from math import radians, sqrt # Major library imports -from numpy import (array, argsort, concatenate, cos, dot, empty, nonzero, - pi, sin, take, ndarray, number) +from numpy import (array, argsort, concatenate, column_stack, cos, dot, empty, + nonzero, pi, sin, take, ndarray, number) # Enthought library imports -from traits.api import ArrayOrNone, Enum +from traits.api import ArrayOrNone, Either, Enum # Exceptions @@ -23,16 +23,24 @@ class DataUpdateError(RuntimeError): class DataInvalidError(ValueError): pass +class DataBoundsError(ValueError): + pass + # Dimensions # A single array of numbers. NumericalSequenceTrait = ArrayOrNone(shape=(None,), value=empty(0)) +# A single array of arbitrary length vectors, or a collection of sequences. +SequenceVectorTrait = ArrayOrNone(shape=(None, None), value=empty(shape=(0, 0))) + # A sequence of pairs of numbers, i.e., an Nx2 array. -PointTrait = ArrayOrNone(shape=(None, 2), value=empty(shape=(0, 2))) +PointSequenceTrait = ArrayOrNone(shape=(None, 2), value=empty(shape=(0, 2))) # An NxM array of numbers. -ImageTrait = ArrayOrNone(shape=(None, None), value=empty(shape=(0, 0))) +ScalarImageTrait = ArrayOrNone(shape=(None, None), value=empty(shape=(0, 0))) +ColorImageTrait = ArrayOrNone(shape=(None, None, (3, 4)), value=empty(shape=(0, 0, 3))) +ImageTrait = Either(ScalarImageTrait, ColorImageTrait) # An 3D array of numbers of shape (Nx, Ny, Nz) CubeTrait = ArrayOrNone(shape=(None, None, None), value=empty(shape=(0, 0, 0))) @@ -69,35 +77,25 @@ def n_gon(center, r, nsides, rot_degrees=0): return [poly_point(center, r, i*theta+rotation) for i in range(nsides)] -# Ripped from Chaco 1.0's plot_base.py def bin_search(values, value, ascending): """ Performs a binary search of a sorted array looking for a specified value. Returns the lowest position where the value can be found or where the array value is the last value less (greater) than the desired value. - Returns -1 if *value* is beyond the minimum or maximum of *values*. + Returns -1 if `value` is beyond the minimum or maximum of `values`. """ - ascending = ascending > 0 - if ascending: + if ascending > 0: if (value < values[0]) or (value > values[-1]): return -1 + index = values.searchsorted(value, 'right') - 1 else: if (value < values[-1]) or (value > values[0]): return -1 - lo = 0 - hi = len( values ) - while True: - mid = (hi + lo) // 2 - midval = values[ mid ] - if midval == value: - return mid - elif (ascending and midval > value) or (not ascending and midval < value): - hi = mid - else: - lo = mid - if lo >= (hi - 1): - return lo + ascending_values = values[::-1] + index = len(values) - ascending_values.searchsorted(value, 'left') - 1 + return index + def reverse_map_1d(data, pt, sort_order, floor_only=False): """Returns the index of *pt* in the array *data*. @@ -129,7 +127,6 @@ def reverse_map_1d(data, pt, sort_order, floor_only=False): if ndx == -1: raise IndexError("value outside array data range") - # Now round the index to the closest matching index. Do this # by determining the width (in value space) of each cell and # figuring out which side of the midpoint pt falls into. Since diff --git a/chaco/base_array_data_source.py b/chaco/base_array_data_source.py index 6e1eaad94..12bffca03 100644 --- a/chaco/base_array_data_source.py +++ b/chaco/base_array_data_source.py @@ -10,16 +10,12 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals -from contextlib import contextmanager -from numpy import isfinite, isnan, nanmax, nanmin, ones +from traits.api import ArrayOrNone, Array, Property, cached_property -from traits.api import ArrayOrNone, Bool, Either, Instance, Tuple +from .base_data_source import BaseDataSource -from .abstract_data_source import AbstractDataSource -from .base import DataInvalidError, DataUpdateError - -class BaseArrayDataSource(AbstractDataSource): +class BaseArrayDataSource(BaseDataSource): """ Base class for data sources which store data in a NumPy array This class provides basic implementation of the AbstractDataSource @@ -70,7 +66,7 @@ class BaseArrayDataSource(AbstractDataSource): def __init__(self, data=None, mask=None, **traits): super(BaseArrayDataSource, self).__init__(**traits) - self.set_data(data, mask) + self.set_data(data, mask=mask) def set_data(self, data, mask=None): """ Set the value of the data and (optional) mask. @@ -107,12 +103,10 @@ def set_data(self, data, mask=None): incompatible, calls to `get_data_mask` will raise a `ValueError`. """ - with self._updating_data(): + with self.updating_data(): self._data = data - self._finite_mask = None - self._cached_bounds = None if mask is not None: - self._mask = mask + self._user_mask = mask def set_mask(self, mask): """ Set the value of the mask @@ -143,8 +137,8 @@ def set_mask(self, mask): `get_data_mask` will raise a `ValueError`. """ - with self._updating_data(): - self._mask = mask + with self.updating_data(): + self._user_mask = mask def remove_mask(self): """ Remove the mask @@ -162,158 +156,6 @@ def remove_mask(self): """ self.set_mask(None) - def invalidate_data(self): - """ Mark the data value as invalid. - - Data can only become valid again after a successful `set_data`. - - """ - self._data_valid = False - - def access_guard(self): - """ Context manager that detects data becoming invalid during access - - This listens for changes in internal state to detect if the - data is initially invalid, or the `invalidate_data` method has been - called during the context. - - Returns - ------- - - access_guard : context manager - A context manager that raises a DataInvalidError if the data is - invalidated. - - Raises - ------ - - DataInvalidError: - When the data is invalid. - - """ - return Guard(self) - - #------------------------------------------------------------------------ - # AbstractDataSource interface - #------------------------------------------------------------------------ - - def get_data(self): - """Get an array representing the data stored in the data source. - - Returns - ------- - - data_array : array of values - An array of the dimensions specified by the index and value - dimension traits. This data array must not be altered in-place, - and the caller must assume it is read-only. This data is - contiguous and not masked. - - """ - with self.access_guard(): - data = self._get_data_unsafe() - return data - - def get_data_mask(self): - """Get arrays representing the data and the mask of the data source. - - Returns - ------- - - data_array, mask: array of values, array of bool - Returns the full source data array and a corresponding binary - mask array. Treat both arrays as read-only. - - Raises - ------ - - ValueError: - If mask's shape is incompatible with the data shape. - - """ - with self.access_guard(): - data, mask = self._get_data_mask_unsafe() - - return data, mask - - def is_masked(self): - """Whether or not the data is masked. - - Returns - ------- - - is_masked : bool - True if this data source's data uses a mask of has non-finite - values. - - Raises - ------ - - ValueError: - If mask's shape is incompatible with the data shape. - - """ - with self.access_guard(): - self._compute_finite() - masked = self._mask is not None or not self._finite - return masked - - def get_size(self): - """The size of the data. - - This method is useful for down-sampling. - - Returns - ------- - - size : tuple of ints - Returns the shape of the data for the index dimensions. - - """ - with self.access_guard(): - data = self._get_data_unsafe() - - return data.shape[:self.dimension] - - def get_bounds(self): - """ Get the minimum and maximum finite values of the data. - - Returns - ------- - - bounds : tuple of min, max - A tuple (min, max) of the bounding values for the data source. - In the case of n-dimensional data values, min and max are - n-dimensional points that represent the bounding corners of a - rectangle enclosing the data set. - - Raises - ------ - - ValueError: - If an all-nan axis is found. - - TypeError: - If the dtype is not appropriate for min and max (eg. strings) - - """ - with self.access_guard(): - if self._cached_bounds is None: - self._compute_bounds() - bounds = self._cached_bounds - - return bounds - - #------------------------------------------------------------------------ - # Event handlers - #------------------------------------------------------------------------ - - def _metadata_changed(self, event): - self.metadata_changed = True - - def _metadata_items_changed(self, event): - self.metadata_changed = True - #------------------------------------------------------------------------ # Private interface #------------------------------------------------------------------------ @@ -323,29 +165,13 @@ def _metadata_items_changed(self, event): _data = ArrayOrNone #: the user-supplied mask. - _mask = ArrayOrNone(dtype=bool) - - #: locations where the data is finite - _finite_mask = ArrayOrNone(dtype=bool) - - #: the actual mask taking into account non-finite values. - _cached_mask = ArrayOrNone(dtype=bool) - - #: whether or not the data is finite at all locations - _finite = Bool + _user_mask = ArrayOrNone(dtype=bool) - #: the min, max bounds of the (unmasked) data - _cached_bounds = Either(Tuple, None) - - #: a flag indicating whether the data is currently valid - _data_valid = Bool(False) - - #: a lock that should be acquired before modifying the data. This should - #: usually be acquired in a non-blocking fashion - _update_lock = Instance('threading.Lock', args=()) + #: the actual mask, combining the user and finite masks + _mask = Property(Array(dtype=bool), depends_on='_data_valid') def _get_data_unsafe(self): - """ Return the data without aquiring the update lock + """ Return the data without guarding against changes This method can be safely called by subclasses via super() without worrying about locking. @@ -356,17 +182,20 @@ def _get_data_unsafe(self): return self._empty_data() return data - def _get_data_mask_unsafe(self): - """ Return the data without aquiring the update lock + def _get_mask_unsafe(self): + """ Return the data without guarding against changes This method can be safely called by subclasses via super() without worrying about locking. """ - data = self._get_data_unsafe() - self._compute_mask() - mask = self._cached_mask - return data, mask + return self._mask + + def _is_masked_unsafe(self): + """ Is the data masked, without guarding against changes + + """ + return self._user_mask is not None or not self._finite def _empty_data(self): """ Method that returns an empty array of the appropriate type @@ -376,112 +205,11 @@ def _empty_data(self): """ raise NotImplementedError - def _compute_mask(self): - """ Compute the mask and cache it """ - if self._cached_mask is None: - self._compute_finite() - if self._mask is None: - self._cached_mask = self._finite_mask - elif self._finite: - self._cached_mask = self._mask - else: - self._cached_mask = self._finite_mask & self._mask - - def _compute_finite(self): - """ Compute locations where data is finite - - Subclasses with complex dtypes may need to override this method. - - """ - if self._finite_mask is None: - data = self._get_data_unsafe() - non_index_axes = tuple(range(self.dimension, len(data.shape))) - try: - self._finite_mask = isfinite(data).all(axis=non_index_axes) - self._finite = self._finite_mask.all() - except TypeError: - # dtype for which isfinite doesn't work; finite by definition - self._finite_mask = ones(shape=data.shape[:self.dimension], - dtype=bool) - self._finite = True - - def _compute_bounds(self): - """ Compute bounds of values - - Subclasses may override this to avoid un-needed computation in cases - where minimum and maximum values are known (eg. sorted data). - - Raises - ------ - - ValueError: - If an all-nan axis is found. - - TypeError: - If the dtype is not appropriate for min and max (eg. strings) - - """ - data = self._get_data_unsafe() - - index_axes = tuple(range(self.dimension)) - min_value = nanmin(data, axis=index_axes) - max_value = nanmax(data, axis=index_axes) - - # only need to check min value: nans in max_value imply nans here too - if isnan(min_value).any(): - raise ValueError("All-NaN axis encountered") - - self._cached_bounds = (min_value, max_value) - - - @contextmanager - def _updating_data(self): - """ Context manager for updating data - - In the enter method this manager attempts to acquire the update lock, - raising an exception if it is not immediately available, then sets the - validity of the data to False. - - In the exit method it ensures that the update lock is released, and - sets the - """ - acquired = self._update_lock.acquire(False) - if not acquired: - msg = "data update conflict, {0} cannot acquire lock" - raise DataUpdateError(msg.format(self)) - try: - self.invalidate_data() - yield self - self._cached_mask = None - self._data_valid = True - finally: - self._update_lock.release() - - self.data_changed = True - - -class Guard(object): - """ Guard against trait becoming invalid during access - - Lightweight context manager that raises an exception if trait becomes - invalid during an operation. - - """ - - def __init__(self, obj, valid_trait='_data_valid'): - self.valid_trait = valid_trait - self.valid = True - self.obj = obj - - def valid_change_listener(self, new): - self.valid = self.valid and new - - def __enter__(self): - self.obj.on_trait_change(self.valid_change_listener, self.valid_trait) - self.valid = self.valid and getattr(self.obj, self.valid_trait) - return self - - def __exit__(self, exc_type, exc_value, tb): - if not self.valid: - msg = "data source {0} data is not valid" - raise DataInvalidError(msg.format(self)) + @cached_property + def _get__mask(self): + if self._user_mask is None: + return self._finite_mask + elif self._finite: + return self._user_mask + else: + return self._user_mask & self._finite_mask diff --git a/chaco/base_data_source.py b/chaco/base_data_source.py new file mode 100644 index 000000000..bdb6ea666 --- /dev/null +++ b/chaco/base_data_source.py @@ -0,0 +1,410 @@ +""" +Defines the BaseDataSource class. + +This is a base class that implements common logic for reasonably safe access +to data, raising exceptions if data is modified during a get operation, and +locking to prevent simultaneous set operations on different threads. + +""" + +from __future__ import \ + absolute_import, division, print_function, unicode_literals + +from contextlib import contextmanager +from numpy import isfinite, isnan, nanmax, nanmin, ones, prod + +from traits.api import \ + Array, Bool, Event, Instance, Property, Tuple, cached_property + +from .abstract_data_source import AbstractDataSource +from .base import DataBoundsError, DataInvalidError, DataUpdateError + + +class BaseDataSource(AbstractDataSource): + """ Base class for data sources which store data in a NumPy array + + This class provides basic implementation of the AbstractDataSource + interface on top of a numpy array. The class also guards against + accessing the data while a change to the data is under way. + + Notes + ----- + + This class is abstract and shouldn't be instantiated directly. This class + also should not be used as an interface: plots and renderers shouldn't care + about the mechanics of data source internals (ie. whether the data is in + an array) but on the dimensionality, value type, masking, etc. + + Subclasses must provide valid `dimension` and `value_type` traits, and + implement the private `_empty_data()` method to return an appropriate + value when the data is set to `None` (usually an empty array of the + correct dimensionality, but with zeroed shape). + + The constructor does not check that the data array and mask array have the + correct dimension, value type or compatible shapes. Subclasses should use + appropriate trait types to ensure that the underlying arrays have + appropriate dimension and value type. + + The implementation is designed for arrays that fit comfortably in memory + and, in particular, for which is is not an imposition to create temporary + copies and masks, since the `nanmin` and `nanmax` functions do that under + the covers. + + """ + #------------------------------------------------------------------------ + # BaseDataSource interface + #------------------------------------------------------------------------ + + def invalidate_data(self): + """ Mark the data value as invalid. + + Data can only become valid again after data is successfully updated. + + """ + self._data_valid = False + + def access_guard(self): + """ Context manager that detects data becoming invalid during access + + This listens for changes in internal state to detect if the + data is initially invalid, or the `invalidate_data` method has been + called during the context. + + Returns + ------- + + access_guard : context manager + A context manager that raises a DataInvalidError if the data is + invalidated. + + Raises + ------ + + DataInvalidError: + When the data is invalid. + + """ + return Guard(self) + + @contextmanager + def updating_data(self): + """ Context manager for updating data + + Subclasses should guard data-changing operations with this context + manager. + + In the enter method this manager attempts to acquire the update lock, + raising an exception if it is not immediately available, then sets the + validity of the data to False. + + In the exit method it ensures that the update lock is released, + and if there was no error, flags the data as valid, clears the caches, + and fires the data changed event. + + """ + acquired = self._update_lock.acquire(False) + if not acquired: + msg = "data update conflict, {0} cannot acquire lock" + raise DataUpdateError(msg.format(self)) + try: + self.invalidate_data() + yield self + self._data_valid = True + finally: + self._update_lock.release() + + self.data_changed = True + + #------------------------------------------------------------------------ + # AbstractDataSource interface + #------------------------------------------------------------------------ + + def get_data(self): + """Get an array representing the data stored in the data source. + + Returns + ------- + + data_array : array of values + An array of the dimensions specified by the index and value + dimension traits. This data array must not be altered in-place, + and the caller must assume it is read-only. This data is + contiguous and not masked. + + """ + with self.access_guard(): + data = self._get_data_unsafe() + return data + + def get_data_mask(self): + """Get arrays representing the data and the mask of the data source. + + Returns + ------- + + data_array, mask: array of values, array of bool + Returns the full source data array and a corresponding binary + mask array. Treat both arrays as read-only. + + Raises + ------ + + ValueError: + If mask's shape is incompatible with the data shape. + + """ + with self.access_guard(): + data = self._get_data_unsafe() + mask = self._get_mask_unsafe() + + return data, mask + + def is_masked(self): + """Whether or not the data is masked. + + Returns + ------- + + is_masked : bool + True if this data source's data uses a mask of has non-finite + values. + + Raises + ------ + + ValueError: + If mask's shape is incompatible with the data shape. + + """ + with self.access_guard(): + is_masked = self._is_masked_unsafe() + return is_masked + + def get_shape(self): + """The size of the data. + + This method is useful for down-sampling. + + Returns + ------- + + size : tuple of ints + Returns the shape of the data for the index dimensions. + + """ + with self.access_guard(): + shape = self._get_shape_unsafe() + return shape + + def get_size(self): + """The size of the data. + + This method is useful for down-sampling. + + Returns + ------- + + size : tuple of ints + Returns the shape of the data for the index dimensions. + + """ + return prod(self.get_shape()) + + def get_bounds(self): + """ Get the minimum and maximum finite values of the data. + + Returns + ------- + + bounds : tuple of min, max + A tuple (min, max) of the bounding values for the data source. + In the case of n-dimensional data values, min and max are + n-dimensional points that represent the bounding corners of a + rectangle enclosing the data set. + + Raises + ------ + + ValueError: + If an all-nan axis is found. + + TypeError: + If the dtype is not appropriate for min and max (eg. strings) + + """ + with self.access_guard(): + bounds = self._get_bounds_unsafe() + return bounds + + #------------------------------------------------------------------------ + # Event handlers + #------------------------------------------------------------------------ + + def _metadata_changed(self, event): + self.metadata_changed = True + + def _metadata_items_changed(self, event): + self.metadata_changed = True + + #------------------------------------------------------------------------ + # Private interface + #------------------------------------------------------------------------ + + #: whether or not the data is finite at all values + _finite = Property(Bool, depends_on='_data_valid') + + #: a mask of finite values of the data. + _finite_mask = Property(Array(dtype=bool), depends_on='_data_valid') + + #: the min, max bounds of the (unmasked) data + _bounds = Property(Tuple, depends_on='_data_valid') + + #: a flag indicating whether the data is currently valid + _data_valid = Bool(False) + + #: a lock that should be acquired before modifying the data. This should + #: usually be acquired in a non-blocking fashion + _update_lock = Instance('threading.Lock', args=()) + + def _get_data_unsafe(self): + """ Return the data without aquiring the update lock + + This method can be safely called by subclasses via super() without + worrying about locking. + + """ + raise NotImplementedError + + def _get_mask_unsafe(self): + """ Return the data without aquiring the update lock + + A default method which computes where the array is finite is provided. + + Subclasses should override this method if they have more sophisticated + mask handling. This method can be safely called by subclasses via + super() without worrying about locking. + + """ + return self._finite_mask + + def _is_masked_unsafe(self): + """ Return the data without aquiring the update lock + + A default implementation which returns True if there are any infinte + values. + + Subclasses may want to override this method if they have more + sophisticated mask handling. This method can be safely called by + subclasses via super() without worrying about locking. + + """ + print('called _is_masked_unsafe') + return not self._finite + + def _get_shape_unsafe(self): + """ Return the data without aquiring the update lock + + A default implementation is provided. Subclasses may want to override + as needed. This method can be safely called by subclasses via super() + without worrying about locking. + + """ + data = self._get_data_unsafe() + return data.shape[:self.dimension] + + def _get_bounds_unsafe(self): + """ Return the data without aquiring the update lock + + A default implementation is provided. Subclasses may want to override + as needed, particularly if they have complex value types. This method + can be safely called by subclasses via super() without worrying about + locking. + + """ + return self._bounds + + @cached_property + def _get__finite_mask(self): + """ Compute the mask and cache it """ + data = self._get_data_unsafe() + non_index_axes = tuple(range(self.dimension, len(data.shape))) + try: + return isfinite(data).all(axis=non_index_axes) + except TypeError: + # dtype for which isfinite doesn't work; finite by definition + return ones(shape=self._get_shape_unsafe(), dtype=bool) + + @cached_property + def _get__finite(self): + """ Whether or not the data is finite in all positions + + This can be overridden in situations where it is known that the data + is finite, or if there is a more efficient way to compute the fact. + + """ + print('called _get_finite') + return self._finite_mask.all() + + @cached_property + def _get__bounds(self): + """ Compute bounds of values, setting the `_cached_bounds` attribute + + Subclasses may override this to avoid un-needed computation in cases + where minimum and maximum values are known (eg. sorted data). + + Returns + ------- + + bounds : tuple of (min, max) + + Raises + ------ + + DataBoundsError: + If the data is empty, has an all-NaN axis, or otherwise has a + value where the bounds can't be computed. + + TypeError: + If the dtype is not appropriate for min and max (eg. strings) + + """ + data = self._get_data_unsafe() + if prod(data.shape[:self.dimension]) == 0: + raise DataBoundsError("Empty data has no valid bounds") + + index_axes = tuple(range(self.dimension)) + min_value = nanmin(data, axis=index_axes) + max_value = nanmax(data, axis=index_axes) + + # only need to check min value: nans in max_value imply nans here too + if isnan(min_value).any(): + raise DataBoundsError("All-NaN axis encountered") + + return (min_value, max_value) + + +class Guard(object): + """ Guard against trait becoming invalid during access + + Lightweight context manager that raises an exception if trait becomes + invalid during an operation. + + """ + + def __init__(self, obj, valid_trait='_data_valid'): + self.valid_trait = valid_trait + self.valid = True + self.obj = obj + + def valid_change_listener(self, new): + self.valid = self.valid and new + + def __enter__(self): + self.obj.on_trait_change(self.valid_change_listener, self.valid_trait) + self.valid = self.valid and getattr(self.obj, self.valid_trait) + return self + + def __exit__(self, exc_type, exc_value, tb): + if not self.valid: + msg = "data source {0} data is not valid" + raise DataInvalidError(msg.format(self)) diff --git a/chaco/tests/base_array_data_source_test_case.py b/chaco/tests/base_array_data_source_test_case.py index 03258374a..1097bf22d 100644 --- a/chaco/tests/base_array_data_source_test_case.py +++ b/chaco/tests/base_array_data_source_test_case.py @@ -13,7 +13,7 @@ from traits.api import ReadOnly, TraitError from traits.testing.unittest_tools import UnittestTools -from chaco.base import DataInvalidError, DataUpdateError +from chaco.base import DataInvalidError, DataUpdateError, DataBoundsError from chaco.base_array_data_source import BaseArrayDataSource @@ -165,7 +165,7 @@ def test_set_data_all_nan(self): self.validate_data(data_source, new_data) - with self.assertRaises(ValueError): + with self.assertRaises(DataBoundsError): self.empty_data_source.get_bounds() def test_set_data_invalid(self): diff --git a/chaco/tests/base_data_source_test_case.py b/chaco/tests/base_data_source_test_case.py new file mode 100644 index 000000000..d7f3f8e1a --- /dev/null +++ b/chaco/tests/base_data_source_test_case.py @@ -0,0 +1,358 @@ +""" +Test cases for the BaseDataSource class. +""" + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +import unittest2 as unittest +import mock +from numpy import arange, array, empty, inf, isfinite, NaN, ones, prod +from numpy.testing import assert_array_equal + +from traits.api import Array, ReadOnly, TraitError +from traits.testing.unittest_tools import UnittestTools + +from chaco.base import DataInvalidError, DataUpdateError, DataBoundsError +from chaco.base_data_source import BaseDataSource + + +class TestArrayDataSource(BaseDataSource): + """ Very simple implementation to test basic functionality """ + + #: a 1-D array + dimension = ReadOnly(1) + + #: with scalar values + value_type = ReadOnly('scalar') + + def __init__(self, **traits): + super(TestArrayDataSource, self).__init__(**traits) + self._data_valid = True + + +class BaseDataSourceTestCase(unittest.TestCase, UnittestTools): + """ Test cases for the BaseArrayDataSource class. """ + + def setUp(self): + self.data = arange(12.0) + self.data[5] = NaN + self.data_source = TestArrayDataSource() + self.data_source._get_data_unsafe = mock.MagicMock( + return_value=self.data) + + self.empty_data = empty(shape=(0,), dtype=float) + self.empty_data_source = TestArrayDataSource() + self.empty_data_source._get_data_unsafe = mock.MagicMock( + return_value=self.empty_data) + + # something like an array of colors + self.vector_data = arange(12.0).reshape(4, 3) + self.vector_data_source = TestArrayDataSource() + self.vector_data_source._get_data_unsafe = mock.MagicMock( + return_value=self.vector_data) + + # an array of ints + self.int_data = arange(12) + self.int_data_source = TestArrayDataSource() + self.int_data_source._get_data_unsafe = mock.MagicMock( + return_value=self.int_data) + + # an array of strings + self.text_data = array(['zero', 'one', 'two', 'three', 'four', 'five', + 'six', 'seven', 'eight', 'nine', 'ten', + 'eleven']) + self.text_data_source = TestArrayDataSource() + self.text_data_source._get_data_unsafe = mock.MagicMock( + return_value=self.text_data) + + def test_data(self): + self.validate_data(self.data_source, self.data) + self.data_source._get_data_unsafe.reset_mock() + + self.data_source._get_data_unsafe.reset_mock() + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + self.assertTrue(self.data_source._get_data_unsafe.called) + + # and check caching of result of get_bounds + self.data_source._get_data_unsafe.reset_mock() + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + self.assertFalse(self.data_source._get_data_unsafe.called) + + def test_empty_data(self): + data_source = self.empty_data_source + self.validate_data(data_source, self.empty_data) + with self.assertRaises(DataBoundsError): + data_source.get_bounds() + + def test_vector_data(self): + data_source = self.vector_data_source + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds(), + array([[0.0, 1.0, 2.0], [9.0, 10.0, 11.0]])) + + def test_int_data(self): + data_source = self.int_data_source + self.validate_data(data_source, self.int_data) + assert_array_equal(data_source.get_bounds(), (0, 11)) + + def test_text_data(self): + data_source = self.text_data_source + self.validate_data(data_source, self.text_data) + with self.assertRaises(TypeError): + data_source.get_bounds() + + def test_vector_nan(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = NaN + data_source.data = self.vector_data + + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds(), + array([[0.0, 1.0, 2.0], [9.0, 10.0, 11.0]])) + + + def test_updating_data_update_lock_fail(self): + data_source = self.data_source + + with self.assertTraitDoesNotChange(data_source, 'data_changed'): + with self.assertRaises(DataUpdateError): + with data_source._update_lock: + with self.data_source.updating_data(): + pass + + # data should be unmodified by failed update, and still valid + self.validate_data(data_source, self.data) + + def test_invalidate_data(self): + data_source = self.data_source + # validate everything to ensure that caches are filled + self.validate_data(data_source, self.data) + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + + with self.assertTraitDoesNotChange(data_source, 'data_changed'): + data_source.invalidate_data() + + self.check_invalid(data_source) + + # now check that we can reset to valid state + self.data_source._get_data_unsafe.reset_mock() + + with self.assertTraitChanges(data_source, 'data_changed', 1): + with data_source.updating_data(): + pass + + # if caches are filled, is_masked() test should fail + self.validate_data(data_source, self.data) + + # if caches are filled, this should fail + self.data_source._get_data_unsafe.reset_mock() + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + self.assertTrue(self.data_source._get_data_unsafe.called) + + def test_access_guard(self): + data_source = self.data_source + # validate everything to ensure that caches are filled + self.validate_data(data_source, self.data) + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + + with self.assertRaises(DataInvalidError): + with data_source.access_guard(): + data_source.invalidate_data() + + self.check_invalid(data_source) + + # now check that we can reset to valid state + self.data_source._get_data_unsafe.reset_mock() + + with self.assertTraitChanges(data_source, 'data_changed', 1): + with data_source.updating_data(): + pass + + # if caches are filled, is_masked() test should fail + self.validate_data(data_source, self.data) + + # if caches are filled, this should fail + self.data_source._get_data_unsafe.reset_mock() + self.assertEquals(self.data_source.get_bounds(), (0.0, 11.0)) + self.assertTrue(self.data_source._get_data_unsafe.called) + + def test_get_bounds_all_nan(self): + data_source = self.data_source + self.data[:] = NaN + + self.validate_data(data_source, self.data) + + with self.assertRaises(DataBoundsError): + self.empty_data_source.get_bounds() + + def test_get_bounds_plus_infinity(self): + self.data[3] = inf + data_source = self.data_source + + self.assertEquals(data_source.get_bounds(), (0, inf)) + + def test_get_bounds_all_plus_infinity(self): + self.data[:] = inf + data_source = self.data_source + + self.assertEquals(data_source.get_bounds(), (inf, inf)) + + def test_get_bounds_minus_infinity(self): + self.data[3] = -inf + data_source = self.data_source + + self.assertEquals(data_source.get_bounds(), (-inf, 11)) + + def test_get_bounds_all_minus_infinity(self): + self.data[:] = -inf + data_source = self.data_source + + self.assertEquals(data_source.get_bounds(), (-inf, -inf)) + + def test_get_bounds_plus_minus_infinity(self): + self.data[3] = inf + self.data[7] = -inf + data_source = self.data_source + + self.assertEquals(data_source.get_bounds(), (-inf, inf)) + + def test_get_bounds_vector_plus_infinity(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = inf + data_source.data = self.vector_data + + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds()[0], + array([0.0, 1.0, 2.0])) + assert_array_equal(data_source.get_bounds()[1], + array([9.0, 10.0, inf])) + + def test_get_bounds_vector_minus_infinity(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = -inf + data_source.data = self.vector_data + + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds()[0], + array([0.0, 1.0, -inf])) + assert_array_equal(data_source.get_bounds()[1], + array([9.0, 10.0, 11.0])) + + def test_get_bounds_vector_plus_minus_infinity(self): + data_source = self.vector_data_source + self.vector_data[2, 2] = -inf + self.vector_data[2, 1] = inf + data_source.data = self.vector_data + + self.validate_data(data_source, self.vector_data) + assert_array_equal(data_source.get_bounds()[0], + array([0.0, 1.0, -inf])) + assert_array_equal(data_source.get_bounds()[1], + array([9.0, inf, 11.0])) + + def test_get_bounds_vector_nan_vector(self): + data_source = self.vector_data_source + self.vector_data[:, 2] = NaN + data_source.data = self.vector_data + + self.validate_data(data_source, self.vector_data) + with self.assertRaises(ValueError): + data_source.get_bounds() + + def test_metadata_changed(self): + with self.assertTraitChanges(self.data_source, 'metadata_changed', + count=1): + self.data_source.metadata = {'new_metadata': True} + + def test_metadata_items_changed(self): + with self.assertTraitChanges(self.data_source, 'metadata_changed', + count=1): + self.data_source.metadata['new_metadata'] = True + + def test_get_data_not_implemented(self): + data_source = BaseDataSource() + + self.check_invalid(data_source) + + def test_get_data_not_implemented_valid_data(self): + data_source = BaseDataSource(_data_valid=True) + + with self.assertRaises(NotImplementedError): + data_source.get_data() + + with self.assertRaises(NotImplementedError): + data_source.get_data_mask() + + with self.assertRaises(NotImplementedError): + data_source.is_masked() + + with self.assertRaises(NotImplementedError): + data_source.get_shape() + + with self.assertRaises(NotImplementedError): + data_source.get_size() + + with self.assertRaises(NotImplementedError): + data_source.get_bounds() + + #### Common validation methods ########################################### + + def validate_data(self, data_source, expected_data): + expected_shape = expected_data.shape[:data_source.dimension] + + axes = tuple(range(1, len(expected_data.shape))) + try: + expected_mask = isfinite(expected_data).all(axis=axes) + except TypeError: + expected_mask = ones(expected_data.shape, dtype=bool) + expected_is_masked = not expected_mask.all() + + # check that get_data() works + data = data_source.get_data() + + self.assertTrue(data_source._get_data_unsafe.called) + assert_array_equal(data, expected_data) + + # check that is_masked() works + data_source._get_data_unsafe.reset_mock() + self.assertEqual(data_source.is_masked(), expected_is_masked) + self.assertTrue(data_source._get_data_unsafe.called) + + # check that get_data_mask() works + data_source._get_data_unsafe.reset_mock() + data, mask = data_source.get_data_mask() + + assert_array_equal(data, expected_data) + assert_array_equal(mask, expected_mask) + self.assertTrue(data_source._get_data_unsafe.called) + + # check that get_shape() works + data_source._get_data_unsafe.reset_mock() + self.assertEquals(data_source.get_size(), expected_shape) + self.assertTrue(data_source._get_data_unsafe.called) + + # check that get_size() works + data_source._get_data_unsafe.reset_mock() + self.assertEquals(data_source.get_size(), prod(expected_shape)) + self.assertTrue(data_source._get_data_unsafe.called) + + def check_invalid(self, data_source): + # check that methods to get data fail + with self.assertRaises(DataInvalidError): + data_source.get_data() + + with self.assertRaises(DataInvalidError): + data_source.get_data_mask() + + with self.assertRaises(DataInvalidError): + data_source.is_masked() + + with self.assertRaises(DataInvalidError): + data_source.get_shape() + + with self.assertRaises(DataInvalidError): + data_source.get_size() + + with self.assertRaises(DataInvalidError): + data_source.get_bounds() From b879647c7d900e2d21fd2f28713962340bc61708 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Mon, 5 Jan 2015 11:21:31 +0000 Subject: [PATCH 30/30] Add PointTrait back for backwards compatibility. --- chaco/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chaco/base.py b/chaco/base.py index bbd108c04..8dea2392f 100644 --- a/chaco/base.py +++ b/chaco/base.py @@ -8,8 +8,8 @@ from math import radians, sqrt # Major library imports -from numpy import (array, argsort, concatenate, column_stack, cos, dot, empty, - nonzero, pi, sin, take, ndarray, number) +from numpy import (array, argsort, concatenate, cos, dot, empty, nonzero, pi, + sin, take, ndarray) # Enthought library imports from traits.api import ArrayOrNone, Either, Enum @@ -36,6 +36,7 @@ class DataBoundsError(ValueError): # A sequence of pairs of numbers, i.e., an Nx2 array. PointSequenceTrait = ArrayOrNone(shape=(None, 2), value=empty(shape=(0, 2))) +PointTrait = PointSequenceTrait # An NxM array of numbers. ScalarImageTrait = ArrayOrNone(shape=(None, None), value=empty(shape=(0, 0)))