diff --git a/bootstrapped/bootstrap.py b/bootstrapped/bootstrap.py index 807aa4b..03a1289 100644 --- a/bootstrapped/bootstrap.py +++ b/bootstrapped/bootstrap.py @@ -114,30 +114,38 @@ def _needs_sparse_unification(values_lists): return False +def _is_sparse_or_array(x): + return isinstance(x, _sparse.csr_matrix) or isinstance(x, _np.ndarray) + + def _validate_arrays(values_lists): - t = values_lists[0] - t_type = type(t) - if not isinstance(t, _sparse.csr_matrix) and not isinstance(t, _np.ndarray): - raise ValueError(('The arrays must either be of type ' - 'scipy.sparse.csr_matrix or numpy.array')) - for _, values in enumerate(values_lists[1:]): - if not isinstance(values, t_type): - raise ValueError('The arrays must all be of the same type') + if not all(map(_is_sparse_or_array, values_lists)): + raise TypeError(('The arrays must either be of type ' + 'scipy.sparse.csr_matrix or numpy.ndarray')) + + types = {type(x) for x in values_lists} + shapes = {x.shape for x in values_lists} - if t.shape != values.shape: - raise ValueError('The arrays must all be of the same shape') + if len(types) != 1: + raise TypeError('The arrays must all be of the same type') - if isinstance(t, _sparse.csr_matrix): - if values.shape[0] > 1: - raise ValueError(('The sparse matrix must have shape 1 row X N' - ' columns')) + if len(shapes) != 1: + raise ValueError('The arrays must all be of the same shape') - if isinstance(t, _sparse.csr_matrix): + shape = next(iter(shapes)) + + if types == {_sparse.csr_matrix}: + if shape[0] > 1: + raise ValueError(('The sparse matrix must have shape 1 row X N' + ' columns')) if _needs_sparse_unification(values_lists): raise ValueError(('The non-zero entries in the sparse arrays' ' must be aligned: see ' 'bootstrapped.unify_sparse_vectors function')) + else: + if len(shape) != 1: + raise ValueError('numpy.ndarray inputs must have .ndim of 1') def _generate_distributions(values_lists, num_iterations): diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 89ed750..0abcefd 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -485,3 +485,23 @@ def test_t_dist(self): mr[1], delta=mr[1] / 100. ) + + def test_input_validation(self): + too_many_dims_numpy = np.ones(shape=(2, 2)) + square_sparse = sparse.csr_matrix(too_many_dims_numpy) + self.assertRaises( + ValueError, + bs.bootstrap_ab, + too_many_dims_numpy, + too_many_dims_numpy, + bs_stats.mean, + bs_compare.percent_change + ) + self.assertRaises( + ValueError, + bs.bootstrap_ab, + square_sparse, + square_sparse, + bs_stats.mean, + bs_compare.percent_change + )