diff --git a/pytest_split_tests/__init__.py b/pytest_split_tests/__init__.py index 9658bdd..f224b51 100644 --- a/pytest_split_tests/__init__.py +++ b/pytest_split_tests/__init__.py @@ -6,21 +6,23 @@ from _pytest.config import create_terminal_writer import pytest +def get_group_size_and_start(total_items, total_groups, group_id): + """Calculate group size and start index.""" + base_size = total_items // total_groups + rem = total_items % total_groups -def get_group_size(total_items, total_groups): - """Return the group size.""" - return int(math.ceil(float(total_items) / total_groups)) + start = base_size * (group_id - 1) + min(group_id - 1, rem) + size = base_size + 1 if group_id <= rem else base_size + return (start, size) -def get_group(items, group_size, group_id): +def get_group(items, total_groups, group_id): """Get the items from the passed in group based on group size.""" - start = group_size * (group_id - 1) - end = start + group_size - - if start >= len(items) or start < 0: + if not 0 < group_id <= total_groups: raise ValueError("Invalid test-group argument") - return items[start:end] + start, size = get_group_size_and_start(len(items), total_groups, group_id) + return items[start:start+size] def pytest_addoption(parser): @@ -40,8 +42,8 @@ def pytest_collection_modifyitems(session, config, items): yield group_count = config.getoption('test-group-count') group_id = config.getoption('test-group') - seed = config.getoption('random-seed', False) - prescheduled_path = config.getoption('prescheduled', None) + seed = config.getoption('random-seed') + prescheduled_path = config.getoption('prescheduled') if not group_count or not group_id: return @@ -70,14 +72,13 @@ def pytest_collection_modifyitems(session, config, items): if test_name in test_dict] unscheduled_tests = [item for item in items if item not in all_prescheduled_tests] - if seed is not False: + if seed is not None: seeded = Random(seed) seeded.shuffle(unscheduled_tests) total_unscheduled_items = len(unscheduled_tests) - group_size = get_group_size(total_unscheduled_items, group_count) - tests_in_group = get_group(unscheduled_tests, group_size, group_id) + tests_in_group = get_group(unscheduled_tests, group_count, group_id) items[:] = tests_in_group + prescheduled_tests items.sort(key=original_order.__getitem__) diff --git a/tests/test_groups.py b/tests/test_groups.py index 8cb9383..8f5a1e9 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -1,25 +1,25 @@ import pytest -from pytest_split_tests import get_group, get_group_size +from pytest_split_tests import get_group, get_group_size_and_start -def test_group_size_computed_correctly_for_even_group(): - expected = 8 - actual = get_group_size(32, 4) # 32 total tests; 4 groups +def test_group_params_computed_correctly_for_even_group(): + expected = [(0, 8), (8, 8), (16, 8), (24, 8)] + actual = [get_group_size_and_start(32, 4, group_id) for group_id in range(1, 5)] # 32 total tests; 4 groups assert expected == actual def test_group_size_computed_correctly_for_odd_group(): - expected = 8 - actual = get_group_size(31, 4) # 31 total tests; 4 groups + expected = [(0, 8), (8, 8), (16, 8), (24, 7)] + actual = [get_group_size_and_start(31, 4, group_id) for group_id in range(1, 5)] # 32 total tests; 4 groups assert expected == actual def test_group_is_the_proper_size(): items = [str(i) for i in range(32)] - group = get_group(items, 8, 1) + group = get_group(items, 4, 1) assert len(group) == 8 @@ -27,7 +27,7 @@ def test_group_is_the_proper_size(): def test_all_groups_together_form_original_set_of_tests(): items = [str(i) for i in range(32)] - groups = [get_group(items, 8, i) for i in range(1, 5)] + groups = [get_group(items, 4, i) for i in range(1, 5)] combined = [] for group in groups: @@ -40,11 +40,11 @@ def test_group_that_is_too_high_raises_value_error(): items = [str(i) for i in range(32)] with pytest.raises(ValueError): - get_group(items, 8, 5) + get_group(items, 4, 5) def test_group_that_is_too_low_raises_value_error(): items = [str(i) for i in range(32)] with pytest.raises(ValueError): - get_group(items, 8, 0) + get_group(items, 4, 0)