Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions bootstrapped/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,30 +114,38 @@ def _needs_sparse_unification(values_lists):
return False


def _is_sparse_or_array(x):
    """Return True if *x* is a scipy.sparse.csr_matrix or a numpy.ndarray.

    Subclasses of either type are accepted as well, because the check is
    done with isinstance.
    """
    return isinstance(x, (_sparse.csr_matrix, _np.ndarray))


def _validate_arrays(values_lists):
    """Check that all input arrays share one supported type and layout.

    Args:
        values_lists: sequence of arrays that will be processed together.
            Each element must be a scipy.sparse.csr_matrix or a
            numpy.ndarray.

    Raises:
        TypeError: if any element is neither a csr_matrix nor an ndarray,
            or if the elements are not all of exactly the same type (an
            empty sequence also ends up here, since it holds zero types).
        ValueError: if the shapes differ; if sparse inputs have more than
            one row or _needs_sparse_unification reports that they need
            unifying; or if ndarray inputs are not one-dimensional.
    """
    if not all(map(_is_sparse_or_array, values_lists)):
        raise TypeError(('The arrays must either be of type '
                         'scipy.sparse.csr_matrix or numpy.ndarray'))

    types = {type(x) for x in values_lists}
    shapes = {x.shape for x in values_lists}

    if len(types) != 1:
        raise TypeError('The arrays must all be of the same type')

    if len(shapes) != 1:
        raise ValueError('The arrays must all be of the same shape')

    # Exactly one type and one shape remain at this point.
    only_type = next(iter(types))
    shape = next(iter(shapes))

    # issubclass rather than equality: a csr_matrix subclass passed the
    # isinstance gate above and must be validated as sparse, not fall
    # through to the ndarray branch with a misleading error message.
    if issubclass(only_type, _sparse.csr_matrix):
        if shape[0] > 1:
            raise ValueError(('The sparse matrix must have shape 1 row X N'
                              ' columns'))
        if _needs_sparse_unification(values_lists):
            raise ValueError(('The non-zero entries in the sparse arrays'
                              ' must be aligned: see '
                              'bootstrapped.unify_sparse_vectors function'))
    else:
        if len(shape) != 1:
            raise ValueError('numpy.ndarray inputs must have .ndim of 1')


def _generate_distributions(values_lists, num_iterations):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,23 @@ def test_t_dist(self):
mr[1],
delta=mr[1] / 100.
)

def test_input_validation(self):
    """bootstrap_ab must raise ValueError for 2x2 dense and sparse inputs."""
    too_many_dims_numpy = np.ones(shape=(2, 2))
    square_sparse = sparse.csr_matrix(too_many_dims_numpy)

    # Same call for both inputs: the 2-D ndarray first, then the
    # two-row csr_matrix built from it.
    for bad_input in (too_many_dims_numpy, square_sparse):
        with self.assertRaises(ValueError):
            bs.bootstrap_ab(
                bad_input,
                bad_input,
                bs_stats.mean,
                bs_compare.percent_change
            )