Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 190 additions & 3 deletions smart_tests/commands/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from ..utils.env_keys import REPORT_ERROR_KEY
from ..utils.fail_fast_mode import (FailFastModeValidateParams, fail_fast_mode_validate,
set_fail_fast_mode, warn_and_exit_if_fail_fast_mode)
from ..utils.input_snapshot import InputSnapshotId
from ..utils.smart_tests_client import SmartTestsClient
from ..utils.typer_types import Duration, Percentage, parse_duration, parse_percentage
from ..utils.typer_types import Duration, Fraction, Percentage, parse_duration, parse_fraction, parse_percentage
from .test_path_writer import TestPathWriter


Expand Down Expand Up @@ -174,6 +175,23 @@ def __init__(
type=fileText(mode="r"),
metavar="FILE"
)] = None,
input_snapshot_id: Annotated[InputSnapshotId | None, InputSnapshotId.as_option()] = None,
print_input_snapshot_id: Annotated[bool, typer.Option(
"--print-input-snapshot-id",
help="Print the input snapshot ID returned from the server instead of the subset results"
)] = False,
bin_target: Annotated[Fraction | None, typer.Option(
"--bin",
help="Split subset into bins, e.g. --bin 1/4",
metavar="INDEX/COUNT",
type=parse_fraction
)] = None,
same_bin_files: Annotated[List[str], typer.Option(
"--same-bin",
help="Keep all tests listed in the file together when splitting; one test per line",
metavar="FILE",
multiple=True
)] = [],
is_get_tests_from_guess: Annotated[bool, typer.Option(
"--get-tests-from-guess",
help="Get subset list from guessed tests"
Expand Down Expand Up @@ -255,9 +273,15 @@ def warn(msg: str):
self.ignore_flaky_tests_above = ignore_flaky_tests_above
self.prioritize_tests_failed_within_hours = prioritize_tests_failed_within_hours
self.prioritized_tests_mapping_file = prioritized_tests_mapping_file
self.input_snapshot_id = input_snapshot_id.value if input_snapshot_id else None
self.print_input_snapshot_id = print_input_snapshot_id
self.bin_target = bin_target
self.same_bin_files = list(same_bin_files)
self.is_get_tests_from_guess = is_get_tests_from_guess
self.use_case = use_case

self._validate_print_input_snapshot_option()

self.file_path_normalizer = FilePathNormalizer(base_path, no_base_path_inference=no_base_path_inference)

self.test_paths: list[list[dict[str, str]]] = []
Expand Down Expand Up @@ -305,7 +329,7 @@ def stdin(self) -> Iterable[str]:
"""

# To avoid the cli continue to wait from stdin
if self.is_get_tests_from_previous_sessions or self.is_get_tests_from_guess:
if self._should_skip_stdin():
return []

if sys.stdin.isatty():
Expand Down Expand Up @@ -404,8 +428,103 @@ def get_payload(self) -> dict[str, Any]:
if self.use_case:
payload['changesUnderTest'] = self.use_case.value

if self.input_snapshot_id is not None:
payload['subsettingId'] = self.input_snapshot_id

split_subset = self._build_split_subset_payload()
if split_subset:
payload['splitSubset'] = split_subset

return payload

def _build_split_subset_payload(self) -> dict[str, Any] | None:
if self.bin_target is None:
if self.same_bin_files:
print_error_and_die(
"--same-bin option requires --bin option.\nPlease set --bin option to use --same-bin",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)
return None

slice_index = self.bin_target.numerator
slice_count = self.bin_target.denominator

if slice_index <= 0 or slice_count <= 0:
print_error_and_die(
"Invalid --bin value. Both index and count must be positive integers.",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)

if slice_count < slice_index:
print_error_and_die(
"Invalid --bin value. The numerator cannot exceed the denominator.",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)

same_bins = self._read_same_bin_files()

return {
"sliceIndex": slice_index,
"sliceCount": slice_count,
"sameBins": same_bins,
}

def _read_same_bin_files(self) -> list[list[TestPath]]:
if not self.same_bin_files:
return []

formatter = self.same_bin_formatter
if formatter is None:
print_error_and_die(
"--same-bin is not supported for this test runner.",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)

same_bins: list[list[TestPath]] = []
seen_tests: set[str] = set()

for same_bin_file in self.same_bin_files:
try:
with open(same_bin_file, "r", encoding="utf-8") as fp:
tests = [line.strip() for line in fp if line.strip()]
except OSError as exc:
print_error_and_die(
f"Failed to read --same-bin file '{same_bin_file}': {exc}",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)

unique_tests = list(dict.fromkeys(tests))

group: list[TestPath] = []
for test in unique_tests:
if test in seen_tests:
print_error_and_die(
f"Error: test '{test}' is listed in multiple --same-bin files.",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)
seen_tests.add(test)

# For type check
assert formatter is not None, "--same -bin is not supported for this test runner"
formatted = formatter(test)
if not formatted:
print_error_and_die(
f"Failed to parse test '{test}' from --same-bin file {same_bin_file}",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)
group.append(formatted)

same_bins.append(group)

return same_bins

def _collect_potential_test_files(self):
LOOSE_TEST_FILE_PATTERN = r'(\.(test|spec)\.|_test\.|Test\.|Spec\.|test/|tests/|__tests__/|src/test/)'
EXCLUDE_PATTERN = r'(BUILD|Makefile|Dockerfile|LICENSE|.gitignore|.gitkeep|.keep|id_rsa|rsa|blank|taglib)|\.(xml|json|jsonl|txt|yml|yaml|toml|md|png|jpg|jpeg|gif|svg|sql|html|css|graphql|proto|gz|zip|rz|bzl|conf|config|snap|pem|crt|key|lock|jpi|hpi|jelly|properties|jar|ini|mod|sum|bmp|env|envrc|sh)$' # noqa E501
Expand Down Expand Up @@ -463,13 +582,75 @@ def request_subset(self) -> SubsetResult:
e, "Warning: the service failed to subset. Falling back to running all tests")
return SubsetResult.from_test_paths(self.test_paths)

def _requires_test_input(self) -> bool:
return (
self.input_snapshot_id is None
and not self.is_get_tests_from_previous_sessions # noqa: W503
and len(self.test_paths) == 0 # noqa: W503
)

def _should_skip_stdin(self) -> bool:
if self.is_get_tests_from_previous_sessions or self.is_get_tests_from_guess:
return True

if self.input_snapshot_id is not None:
if not sys.stdin.isatty():
warn_and_exit_if_fail_fast_mode(
"Warning: --input-snapshot-id is set so stdin will be ignored."
)
return True
return False

def _validate_print_input_snapshot_option(self):
if not self.print_input_snapshot_id:
return

conflicts: list[str] = []
option_checks = [
("--target", self.target is not None),
("--time", self.time is not None),
("--confidence", self.confidence is not None),
("--goal-spec", self.goal_spec is not None),
("--rest", self.rest is not None),
("--bin", self.bin_target is not None),
("--same-bin", bool(self.same_bin_files)),
("--ignore-new-tests", self.ignore_new_tests),
("--ignore-flaky-tests-above", self.ignore_flaky_tests_above is not None),
("--prioritize-tests-failed-within-hours", self.prioritize_tests_failed_within_hours is not None),
("--prioritized-tests-mapping", self.prioritized_tests_mapping_file is not None),
("--get-tests-from-previous-sessions", self.is_get_tests_from_previous_sessions),
("--get-tests-from-guess", self.is_get_tests_from_guess),
("--output-exclusion-rules", self.is_output_exclusion_rules),
("--non-blocking", self.is_non_blocking),
]

for option_name, is_set in option_checks:
if is_set:
conflicts.append(option_name)

if conflicts:
conflict_list = ", ".join(conflicts)
print_error_and_die(
f"--print-input-snapshot-id cannot be used with {conflict_list}.",
self.tracking_client,
Tracking.ErrorEvent.USER_ERROR,
)

def _print_input_snapshot_id_value(self, subset_result: SubsetResult):
if not subset_result.subset_id:
raise click.ClickException(
"Subset request did not return an input snapshot ID. Please re-run the command."
)

click.echo(subset_result.subset_id)

def run(self):
"""called after tests are scanned to compute the optimized order"""

if self.is_get_tests_from_guess:
self._collect_potential_test_files()

if not self.is_get_tests_from_previous_sessions and len(self.test_paths) == 0:
if self._requires_test_input():
if self.input_given:
print_error_and_die("ERROR: Given arguments did not match any tests. They appear to be incorrect/non-existent.", tracking_client, Tracking.ErrorEvent.USER_ERROR) # noqa E501
else:
Expand All @@ -488,6 +669,12 @@ def run(self):

if len(subset_result.subset) == 0:
warn_and_exit_if_fail_fast_mode("Error: no tests found matching the path.")
if self.print_input_snapshot_id:
self._print_input_snapshot_id_value(subset_result)
return

if self.print_input_snapshot_id:
self._print_input_snapshot_id_value(subset_result)
return

# TODO(Konboi): split subset isn't provided for smart-tests initial release
Expand Down
8 changes: 4 additions & 4 deletions smart_tests/commands/test_path_writer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os.path import join
from typing import Callable, Dict, List
from typing import Callable, List

import click

Expand All @@ -19,7 +19,7 @@ class TestPathWriter(object):

def __init__(self, app: Application):
self.formatter = self.default_formatter
self._same_bin_formatter: Callable[[str], Dict[str, str]] | None = None
self._same_bin_formatter: Callable[[str], TestPath] | None = None
self.separator = "\n"
self.app = app

Expand All @@ -43,9 +43,9 @@ def print(self, test_paths: List[TestPath]):
for t in test_paths))

@property
def same_bin_formatter(self) -> Callable[[str], Dict[str, str]] | None:
def same_bin_formatter(self) -> Callable[[str], TestPath] | None:
return self._same_bin_formatter

@same_bin_formatter.setter
def same_bin_formatter(self, v: Callable[[str], Dict[str, str]]):
def same_bin_formatter(self, v: Callable[[str], TestPath]):
self._same_bin_formatter = v
1 change: 1 addition & 0 deletions smart_tests/test_runners/go_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def subset(client: Subset):
test_cases = []
client.formatter = lambda x: f"^{x[1]['name']}$"
client.separator = '|'
client.same_bin_formatter = format_same_bin
client.run()


Expand Down
2 changes: 2 additions & 0 deletions smart_tests/test_runners/gradle.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def exclusion_output_handler(subset_tests, rest_tests):
client.formatter = lambda x: f"--tests {x[0]['name']}"
client.separator = ' '

client.same_bin_formatter = lambda s: [{"type": "class", "name": s}]

client.run()


Expand Down
2 changes: 2 additions & 0 deletions smart_tests/test_runners/maven.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def file2test(f: str) -> List | None:
for root in source_roots:
client.scan(root, '**/*', file2test)

client.same_bin_formatter = lambda s: [{"type": "class", "name": s}]

client.run()


Expand Down
50 changes: 50 additions & 0 deletions smart_tests/utils/input_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Utility type for --input-snapshot-id option."""

import click

from smart_tests.args4p import typer


class InputSnapshotId:
"""Parses either a numeric snapshot ID or @path reference."""

def __init__(self, raw: str):
value = raw
if value.startswith('@'):
file_path = value[1:]
try:
with open(file_path, 'r', encoding='utf-8') as fp:
value = fp.read().strip()
except OSError as exc:
raise click.BadParameter(
f"Failed to read input snapshot ID file '{file_path}': {exc}"
)

try:
parsed = int(value)
except ValueError:
raise click.BadParameter(
f"Invalid input snapshot ID '{value}'. Expected a positive integer."
)

if parsed < 1:
raise click.BadParameter(
"Invalid input snapshot ID. Expected a positive integer."
)

self.value = parsed

def __int__(self) -> int:
return self.value

def __str__(self) -> str:
return str(self.value)

@staticmethod
def as_option() -> typer.Option:
return typer.Option(
"--input-snapshot-id",
help="Reuse reorder results from an existing input snapshot ID or specify @path/to/file to load it",
metavar="ID|@FILE",
type=InputSnapshotId,
)
Loading
Loading