diff --git a/src/benchmark/executions.py b/src/benchmark/executions.py index 681e6c78..c3039580 100644 --- a/src/benchmark/executions.py +++ b/src/benchmark/executions.py @@ -16,6 +16,7 @@ import atheris import contextlib +from typing import Any, Callable, Iterator, List, Tuple, TypeVar # Use new atheris instrumentation only if on new atheris if "instrument_func" in dir(atheris): @@ -23,15 +24,16 @@ instrument_imports = atheris.instrument_imports instrument_all = atheris.instrument_all else: + T = TypeVar("T") - def instrument_func(x): + def instrument_func(x: Callable[..., T]) -> Callable[..., T]: return x def instrument_all(): pass @contextlib.contextmanager - def instrument_imports(*args, **kwargs): + def instrument_imports(*args: Any, **kwargs: Any) -> Iterator[None]: yield None @@ -46,14 +48,14 @@ def instrument_imports(*args, **kwargs): import io -def _set_nonblocking(fd): +def _set_nonblocking(fd: int): """Set the specified fd to a nonblocking mode.""" oflags = fcntl.fcntl(fd, fcntl.F_GETFL) nflags = oflags | os.O_NONBLOCK fcntl.fcntl(fd, fcntl.F_SETFL, nflags) -def _benchmark_child(test_one_input, num_runs, pipe, args, inst_all): +def _benchmark_child(test_one_input: Callable[[bytes], None], num_runs: int, pipe: Tuple[int, int], args: List[str], inst_all: bool): os.close(pipe[0]) os.dup2(pipe[1], 1) os.dup2(pipe[1], 2) @@ -64,7 +66,7 @@ def _benchmark_child(test_one_input, num_runs, pipe, args, inst_all): counter = [0] start = time.time() - def wrapped_test_one_input(data): + def wrapped_test_one_input(data: bytes): counter[0] += 1 if counter[0] == num_runs: print(f"\nbenchmark_duration={time.time() - start}") @@ -76,11 +78,11 @@ def wrapped_test_one_input(data): assert False # Does not return -def run_benchmark(test_one_input, - num_runs, - timeout=10, - inst_all=False, - args=[]): +def run_benchmark(test_one_input: Callable[[bytes], None], + num_runs: int, + timeout: float = 10, + inst_all: bool = False, + args: List[str] = []): """Fuzz test_one_input() in a subprocess. This forks a child, and in the child, runs atheris.Setup(test_one_input) and @@ -140,7 +142,7 @@ def run_benchmark(test_one_input, @instrument_func -def low_cyclomatic(data): +def low_cyclomatic(data: bytes): x = 0 x = 1 x = 2 @@ -244,7 +246,7 @@ def low_cyclomatic(data): @instrument_func -def high_cyclomatic(data): +def high_cyclomatic(data: bytes): for c in data: if c == 0: c = 38 @@ -760,7 +762,7 @@ def high_cyclomatic(data): c = 7 -def json_fuzz(data): +def json_fuzz(data: bytes): try: json.loads(data.decode("utf-8", "surrogatepass")) except Exception as e: @@ -768,7 +770,7 @@ def json_fuzz(data): @instrument_func -def zip_fuzz(data): +def zip_fuzz(data: bytes): try: with io.BytesIO(data) as f: pz = zipfile.ZipFile(f) diff --git a/src/coverage_g3test.py b/src/coverage_g3test.py index 50d6af1d..62468202 100644 --- a/src/coverage_g3test.py +++ b/src/coverage_g3test.py @@ -17,6 +17,7 @@ import dis import re import unittest +from typing import Any, Tuple from unittest import mock with atheris.instrument_imports(): @@ -29,7 +30,7 @@ @atheris.instrument_func -def if_func(a): +def if_func(a: float) -> int: x = a if x: return 2 @@ -38,89 +39,89 @@ def if_func(a): @atheris.instrument_func -def cmp_less(a, b): +def cmp_less(a: float, b: float): return a < b @atheris.instrument_func -def cmp_greater(a, b): +def cmp_greater(a: float, b: float): return a > b @atheris.instrument_func -def cmp_equal_nested(a, b, c): +def cmp_equal_nested(a: float, b: float, c: float) -> bool: return (a == b) == c @atheris.instrument_func -def cmp_const_less(a): +def cmp_const_less(a: float) -> bool: return 1 < a @atheris.instrument_func -def cmp_const_less_inverted(a): +def cmp_const_less_inverted(a: float) -> bool: return a < 1 @atheris.instrument_func -def decorator_instrumented(x): +def decorator_instrumented(x: int): return 2 * x @atheris.instrument_func -def while_loop(a): +def while_loop(a: float): while a: a -= 1 @atheris.instrument_func -def regex_match(re_obj, a): +def regex_match(re_obj: re.Pattern, a: str): re_obj.match(a) @atheris.instrument_func -def starts_with(s, prefix): +def starts_with(s: str, prefix: str): s.startswith(prefix) @atheris.instrument_func -def ends_with(s, suffix): +def ends_with(s: str, suffix: str): s.endswith(suffix) # Verifying that no tracing happens when var args are passed in to # startswith method calls @atheris.instrument_func -def starts_with_var_args(s, *args): +def starts_with_var_args(s: str, *args: Any): s.startswith(*args) # Verifying that no tracing happens when var args are passed in to # endswith method calls @atheris.instrument_func -def ends_with_var_args(s, *args): +def ends_with_var_args(s: str, *args: Any): s.startswith(*args) class FakeStr: - def startswith(self, s, prefix): + def startswith(self, s: str, prefix: str): pass - def endswith(self, s, suffix): + def endswith(self, s: str, suffix: str): pass # Verifying that even though this code gets patched, no tracing happens @atheris.instrument_func -def fake_starts_with(s, prefix): +def fake_starts_with(s: str, prefix: str): fake_str = FakeStr() fake_str.startswith(s=s, prefix=prefix) # Verifying that even though this code gets patched, no tracing happens @atheris.instrument_func -def fake_ends_with(s, suffix): +def fake_ends_with(s: str, suffix: str): fake_str = FakeStr() fake_str.endswith(s, suffix) @@ -161,16 +162,16 @@ def multi_instrumented(x): @mock.patch.object(atheris, "_trace_branch") class CoverageTest(unittest.TestCase): - def testImport(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testImport(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_cmp_mock.side_effect = original_trace_cmp trace_branch_mock.assert_not_called() Sequence.load(b"0\0") trace_branch_mock.assert_called() - def testBranch(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testBranch(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.assert_not_called() if_func(True) first_call_set = trace_branch_mock.call_args_list @@ -188,14 +189,14 @@ def testBranch(self, trace_branch_mock, trace_cmp_mock, self.assertNotEqual(first_call_set, third_call_set) def testWhile( - self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock + self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, trace_regex_match_mock: mock.MagicMock ): trace_branch_mock.assert_not_called() while_loop(1) trace_branch_mock.assert_called() - def testRegex(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testRegex(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.reset_mock() trace_branch_mock.assert_not_called() trace_regex_match_mock.assert_not_called() @@ -204,7 +205,7 @@ def testRegex(self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock.assert_called() def testStrMethods( - self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock + self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, trace_regex_match_mock: mock.MagicMock ): trace_branch_mock.assert_not_called() trace_regex_match_mock.assert_not_called() @@ -252,7 +253,7 @@ def testStrMethods( trace_regex_match_mock.assert_not_called() trace_regex_match_mock.reset_mock() - def assertTraceCmpWas(self, call_args, left, right, op, left_is_const): + def assertTraceCmpWas(self, call_args: Tuple[int, int, int, int, bool], left: int, right: int, op: str, left_is_const: bool): """Compare a _trace_cmp call to expected values.""" # call_args: tuple(left, right, opid, idx, left_is_const) self.assertEqual(call_args[0], left) @@ -260,8 +261,8 @@ def assertTraceCmpWas(self, call_args, left, right, op, left_is_const): self.assertEqual(dis.cmp_op[call_args[2]], op) self.assertEqual(call_args[4], left_is_const) - def testCompare(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testCompare(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_cmp_mock.side_effect = original_trace_cmp self.assertTrue(cmp_less(1, 2)) @@ -297,8 +298,8 @@ def testCompare(self, trace_branch_mock, trace_cmp_mock, self.assertNotEqual(second_cmp_idx, fifth_cmp_idx) self.assertNotEqual(fourth_cmp_idx, fifth_cmp_idx) - def testConstCompare(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testConstCompare(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_cmp_mock.side_effect = original_trace_cmp self.assertTrue(cmp_const_less(2)) @@ -309,8 +310,8 @@ def testConstCompare(self, trace_branch_mock, trace_cmp_mock, self.assertTraceCmpWas(trace_cmp_mock.call_args[0], 1, 3, ">", True) trace_cmp_mock.reset_mock() - def testInstrumentationAppliedOnce(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testInstrumentationAppliedOnce(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.assert_not_called() multi_instrumented(7) trace_branch_mock.assert_called_once() diff --git a/src/coverage_test.py b/src/coverage_test.py index 7f55e570..b90c9558 100644 --- a/src/coverage_test.py +++ b/src/coverage_test.py @@ -15,6 +15,7 @@ import dis import re +from typing import Tuple import unittest from unittest import mock @@ -30,7 +31,7 @@ @atheris.instrument_func -def decorator_instrumented(x): +def decorator_instrumented(x: int) -> int: return 2 * x @@ -39,7 +40,7 @@ def decorator_instrumented(x): @atheris.instrument_func @atheris.instrument_func @atheris.instrument_func -def multi_instrumented(x): +def multi_instrumented(x: int) -> int: return 2 * x @@ -51,8 +52,8 @@ def multi_instrumented(x): @mock.patch.object(atheris, "_trace_branch") class CoverageTest(unittest.TestCase): - def testBasicBlock(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testBasicBlock(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.assert_not_called() coverage_test_helper.simple_func(7) trace_branch_mock.assert_called() @@ -61,8 +62,8 @@ def testBasicBlock(self, trace_branch_mock, trace_cmp_mock, coverage_test_helper.simple_func(2) trace_branch_mock.assert_called() - def testDecoratorBasicBlock(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testDecoratorBasicBlock(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.assert_not_called() decorator_instrumented(7) trace_branch_mock.assert_called() @@ -71,8 +72,8 @@ def testDecoratorBasicBlock(self, trace_branch_mock, trace_cmp_mock, decorator_instrumented(2) trace_branch_mock.assert_called() - def testBranch(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testBranch(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.assert_not_called() coverage_test_helper.if_func(True) first_call_set = trace_branch_mock.call_args_list @@ -90,14 +91,14 @@ def testBranch(self, trace_branch_mock, trace_cmp_mock, self.assertNotEqual(first_call_set, third_call_set) def testWhile( - self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock + self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, trace_regex_match_mock: mock.MagicMock ): trace_branch_mock.assert_not_called() coverage_test_helper.while_loop(1) trace_branch_mock.assert_called() - def testRegex(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testRegex(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.reset_mock() trace_branch_mock.assert_not_called() trace_regex_match_mock.assert_not_called() @@ -106,7 +107,7 @@ def testRegex(self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock.assert_called() def testStrMethods( - self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock + self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, trace_regex_match_mock: mock.MagicMock ): trace_branch_mock.assert_not_called() trace_regex_match_mock.assert_not_called() @@ -160,7 +161,7 @@ def testStrMethods( trace_regex_match_mock.assert_not_called() trace_regex_match_mock.reset_mock() - def assertTraceCmpWas(self, call_args, left, right, op, left_is_const): + def assertTraceCmpWas(self, call_args: Tuple[int, int, int, int, bool], left: int, right: int, op: str, left_is_const: bool): """Compare a _trace_cmp call to expected values.""" # call_args: tuple(left, right, opid, idx, left_is_const) self.assertEqual(call_args[0], left) @@ -168,8 +169,8 @@ def assertTraceCmpWas(self, call_args, left, right, op, left_is_const): self.assertEqual(dis.cmp_op[call_args[2]], op) self.assertEqual(call_args[4], left_is_const) - def testCompare(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testCompare(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_cmp_mock.side_effect = original_trace_cmp self.assertTrue(coverage_test_helper.cmp_less(1, 2)) @@ -205,8 +206,8 @@ def testCompare(self, trace_branch_mock, trace_cmp_mock, self.assertNotEqual(second_cmp_idx, fifth_cmp_idx) self.assertNotEqual(fourth_cmp_idx, fifth_cmp_idx) - def testConstCompare(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testConstCompare(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_cmp_mock.side_effect = original_trace_cmp self.assertTrue(coverage_test_helper.cmp_const_less(2)) @@ -217,8 +218,8 @@ def testConstCompare(self, trace_branch_mock, trace_cmp_mock, self.assertTraceCmpWas(trace_cmp_mock.call_args[0], 1, 3, ">", True) trace_cmp_mock.reset_mock() - def testInstrumentationAppliedOnce(self, trace_branch_mock, trace_cmp_mock, - trace_regex_match_mock): + def testInstrumentationAppliedOnce(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, + trace_regex_match_mock: mock.MagicMock): trace_branch_mock.assert_not_called() multi_instrumented(7) trace_branch_mock.assert_called_once() diff --git a/src/coverage_test_helper.py b/src/coverage_test_helper.py index fd7b8964..5b695fd5 100644 --- a/src/coverage_test_helper.py +++ b/src/coverage_test_helper.py @@ -13,12 +13,15 @@ # limitations under the License. """A helper library for coverage_test.py - coverage is added to this library.""" +import re +from typing import Any -def simple_func(a): + +def simple_func(a: float) -> float: return 2 * a -def if_func(a): +def if_func(a: float) -> int: x = a if x: return 2 @@ -26,72 +29,72 @@ def if_func(a): return 3 -def cmp_less(a, b): +def cmp_less(a: float, b: float) -> bool: return a < b -def cmp_greater(a, b): +def cmp_greater(a: float, b: float) -> bool: return a > b -def cmp_equal_nested(a, b, c): +def cmp_equal_nested(a: float, b: float, c: float) -> bool: return (a == b) == c -def cmp_const_less(a): +def cmp_const_less(a: float) -> bool: return 1 < a -def cmp_const_less_inverted(a): +def cmp_const_less_inverted(a: float) -> bool: return a < 1 -def while_loop(a): +def while_loop(a: float): while a: a -= 1 -def regex_match(re_obj, a): +def regex_match(re_obj: re.Pattern, a: str): re_obj.match(a) -def starts_with(s, prefix): +def starts_with(s: str, prefix: str): s.startswith(prefix) -def ends_with(s, suffix): +def ends_with(s: str, suffix: str): s.endswith(suffix) # Verifying that no tracing happens when var args are passed in to # startswith method calls -def starts_with_var_args(s, *args): +def starts_with_var_args(s: str, *args: Any): s.startswith(*args) # Verifying that no tracing happens when var args are passed in to # endswith method calls -def ends_with_var_args(s, *args): +def ends_with_var_args(s: str, *args: Any): s.startswith(*args) class FakeStr: - def startswith(self, s, prefix): + def startswith(self, s: str, prefix: str): pass - def endswith(self, s, suffix): + def endswith(self, s: str, suffix: str): pass # Verifying that even though this code gets patched, no tracing happens -def fake_starts_with(s, prefix): +def fake_starts_with(s: str, prefix: str): fake_str = FakeStr() fake_str.startswith(s=s, prefix=prefix) # Verifying that even though this code gets patched, no tracing happens -def fake_ends_with(s, suffix): +def fake_ends_with(s: str, suffix: str): fake_str = FakeStr() fake_str.endswith(s, suffix) diff --git a/src/custom_crossover_fuzz_test.py b/src/custom_crossover_fuzz_test.py index 780e6611..4ffdd1c6 100644 --- a/src/custom_crossover_fuzz_test.py +++ b/src/custom_crossover_fuzz_test.py @@ -22,20 +22,20 @@ import fuzz_test_lib # pytype: disable=import-error -def concatenate_crossover(data1, data2, max_size, seed): +def concatenate_crossover(data1: bytes, data2: bytes, max_size: int, seed: int) -> bytes: res = data1 + b"|" + data2 if max_size < len(res): return data1 return res -def noop_crossover(data1, data2, max_size, seed): +def noop_crossover(data1: bytes, data2: bytes, max_size: int, seed: int) -> bytes: print("Hello from crossover") return data1 @atheris.instrument_func -def bytes_comparison(data): +def bytes_comparison(data: bytes): if data == b"a|b|c|d|e": raise RuntimeError("Was a|b|c|d|e") diff --git a/src/custom_mutator_and_crossover_fuzz_test.py b/src/custom_mutator_and_crossover_fuzz_test.py index b9f5cacc..c4e1c00f 100644 --- a/src/custom_mutator_and_crossover_fuzz_test.py +++ b/src/custom_mutator_and_crossover_fuzz_test.py @@ -22,19 +22,19 @@ import fuzz_test_lib # pytype: disable=import-error -def noop_mutator(data, max_size, seed): +def noop_mutator(data: bytes, max_size: int, seed: int) -> bytes: print("Hello from mutator") res = atheris.Mutate(data, len(data)) return res -def noop_crossover(data1, data2, max_size, seed): +def noop_crossover(data1: bytes, data2: bytes, max_size: int, seed: int) -> bytes: print("Hello from crossover") return data1 + data2 @atheris.instrument_func -def test_one_input(data): +def test_one_input(data: bytes): if data == b"AA": raise ("Solved!") diff --git a/src/custom_mutator_fuzz_test.py b/src/custom_mutator_fuzz_test.py index eeab9c07..ada450a1 100644 --- a/src/custom_mutator_fuzz_test.py +++ b/src/custom_mutator_fuzz_test.py @@ -22,7 +22,7 @@ import fuzz_test_lib # pytype: disable=import-error -def compressed_mutator(data, max_size, seed): +def compressed_mutator(data: bytes, max_size: int, seed: int) -> bytes: try: decompressed = zlib.decompress(data) except zlib.error: @@ -33,7 +33,7 @@ def compressed_mutator(data, max_size, seed): @atheris.instrument_func -def compressed_data(data): +def compressed_data(data: bytes): try: decompressed = zlib.decompress(data) except zlib.error: diff --git a/src/function_hooks.py b/src/function_hooks.py index 7dccfc12..4461432d 100644 --- a/src/function_hooks.py +++ b/src/function_hooks.py @@ -345,7 +345,7 @@ def _trace_str( # pylint: enable=g-import-not-at-top -def _hook_str(*args, **kwargs) -> bool: +def _hook_str(*args: Any, **kwargs: Any) -> bool: """Proxy routing str functions through Atheris tracing. Even though bytecode is modified for hooking the str methods, we use this diff --git a/src/fuzz_test.py b/src/fuzz_test.py index 85621b6b..e8910b7d 100644 --- a/src/fuzz_test.py +++ b/src/fuzz_test.py @@ -19,18 +19,19 @@ import time import unittest import zlib +from typing import Callable, NoReturn import atheris import fuzz_test_lib # pytype: disable=import-error -def fail_immediately(data): +def fail_immediately(data: bytes) -> NoReturn: raise RuntimeError("Failed immediately") @atheris.instrument_func -def many_branches(data): +def many_branches(data: bytes): if len(data) < 4: return if data[0] != 12: @@ -46,7 +47,7 @@ def many_branches(data): @atheris.instrument_func -def never_fail(data): +def never_fail(data: bytes): for d in data: if d == 0: pass @@ -57,18 +58,18 @@ def never_fail(data): @atheris.instrument_func -def raise_with_surrogates(data): +def raise_with_surrogates(data: bytes) -> NoReturn: raise RuntimeError("abc \ud927 def") @atheris.instrument_func -def bytes_comparison(data): +def bytes_comparison(data: bytes): if data == b"foobarbazbiz": raise RuntimeError("Was foobarbazbiz") @atheris.instrument_func -def string_comparison(data): +def string_comparison(data: bytes): try: if data.decode("utf-8") == "foobarbazbiz": raise RuntimeError("Was foobarbazbiz") @@ -77,7 +78,7 @@ def string_comparison(data): @atheris.instrument_func -def utf8_comparison(data): +def utf8_comparison(data: bytes): try: decoded = data.decode("utf-8") if decoded == "⾐∾ⶑ➠": @@ -87,7 +88,7 @@ def utf8_comparison(data): @atheris.instrument_func -def nested_utf8_comparison(data): +def nested_utf8_comparison(data: bytes): try: decoded = data.decode("utf-8") x = "sup" @@ -100,19 +101,19 @@ def nested_utf8_comparison(data): @atheris.instrument_func -def timeout_py(data): +def timeout_py(data: bytes): del data time.sleep(100000000) @atheris.instrument_func -def regex_match(data): +def regex_match(data: bytes): if re.search(b"(Sun|Mon)day", data) is not None: raise RuntimeError("Was RegEx Match") @atheris.instrument_func -def str_startswith(data): +def str_startswith(data: bytes): try: decoded = data.decode("utf-8") if decoded.startswith("foobar"): @@ -122,7 +123,7 @@ def str_startswith(data): @atheris.instrument_func -def str_endswith(data): +def str_endswith(data: bytes): try: decoded = data.decode("utf-8") if decoded.endswith("bazbiz"): @@ -132,7 +133,7 @@ def str_endswith(data): @atheris.instrument_func -def str_methods_combined(data): +def str_methods_combined(data: bytes): try: decoded = data.decode("utf-8") if decoded.startswith("foo") and decoded.endswith("bar"): @@ -142,7 +143,7 @@ def str_methods_combined(data): @atheris.instrument_func -def str_startswith_tuple_prefix(data): +def str_startswith_tuple_prefix(data: bytes): try: decoded = data.decode("utf-8") if decoded.startswith(("foobar", "hellohi", "supyo")): @@ -152,7 +153,7 @@ def str_startswith_tuple_prefix(data): @atheris.instrument_func -def str_endswith_tuple_suffix(data): +def str_endswith_tuple_suffix(data: bytes): try: decoded = data.decode("utf-8") if decoded.endswith(("bazbiz", "byebye", "cyalater")): @@ -162,7 +163,7 @@ def str_endswith_tuple_suffix(data): @atheris.instrument_func -def str_startswith_with_start_and_end(data): +def str_startswith_with_start_and_end(data: bytes): try: decoded = data.decode("utf-8") if decoded.startswith("hellohi", 10, 20): @@ -172,7 +173,7 @@ def str_startswith_with_start_and_end(data): @atheris.instrument_func -def str_endswith_with_start_and_end(data): +def str_endswith_with_start_and_end(data: bytes): try: decoded = data.decode("utf-8") if decoded.endswith("supyo", 5, 15): @@ -182,7 +183,7 @@ def str_endswith_with_start_and_end(data): @atheris.instrument_func -def compressed_data(data): +def compressed_data(data: bytes): try: decompressed = zlib.decompress(data) except zlib.error: @@ -199,25 +200,25 @@ def compressed_data(data): @atheris.instrument_func -def reserve_counter_after_fuzz_start(data): +def reserve_counter_after_fuzz_start(data: Callable[[bytes], None]): del data atheris._reserve_counter() @functools.lru_cache(maxsize=None) -def instrument_once(func): +def instrument_once(func: Callable[[bytes], None]): """Instruments func, and verifies that this is the first time.""" assert("__ATHERIS_INSTRUMENTED__" not in func.__code__.co_consts) atheris.instrument_func(func) assert("__ATHERIS_INSTRUMENTED__" in func.__code__.co_consts) -def foo(data): +def foo(data: bytes): if data == b"foobar": raise RuntimeError("Code instrumented at runtime.") -def runtime_instrument_code(data): +def runtime_instrument_code(data: bytes): instrument_once(foo) foo(data) diff --git a/src/fuzz_test_lib.py b/src/fuzz_test_lib.py index 6c07b940..cfb8b522 100644 --- a/src/fuzz_test_lib.py +++ b/src/fuzz_test_lib.py @@ -17,19 +17,20 @@ import signal import sys import time +from typing import Any, Callable, Dict, List, Optional, Tuple import atheris -def _set_nonblocking(fd): +def _set_nonblocking(fd: int): """Set the specified fd to a nonblocking mode.""" oflags = fcntl.fcntl(fd, fcntl.F_GETFL) nflags = oflags | os.O_NONBLOCK fcntl.fcntl(fd, fcntl.F_SETFL, nflags) -def _fuzztest_child(test_one_input, custom_setup, setup_kwargs, pipe, args, - enabled_hooks): +def _fuzztest_child(test_one_input: Callable[[bytes], None], custom_setup: Callable[..., Any], setup_kwargs: Optional[Dict[str, str]], pipe: Tuple[int, int], args: Optional[List[str]], + enabled_hooks: Optional[List[str]]): """Fuzzing target to run as a separate process.""" os.close(pipe[0]) os.dup2(pipe[1], 1) @@ -61,13 +62,13 @@ def _fuzztest_child(test_one_input, custom_setup, setup_kwargs, pipe, args, os._exit(0) -def run_fuzztest(test_one_input, - custom_setup=None, - setup_kwargs=None, - expected_output=None, - timeout=10, - args=None, - enabled_hooks=None): +def run_fuzztest(test_one_input: Callable[[bytes], Any], + custom_setup: Optional[Callable[..., Any]] = None, + setup_kwargs: Optional[Dict[str, str]] = None, + expected_output: Optional[bytes] = None, + timeout: float = 10, + args: Optional[List[str]] = None, + enabled_hooks: Optional[List[str]] = None): """Fuzz test_one_input() in a subprocess. This forks a child, and in the child, runs atheris.Setup(test_one_input) and diff --git a/src/fuzzed_data_provider_test.py b/src/fuzzed_data_provider_test.py index 722df76b..6bc5094e 100644 --- a/src/fuzzed_data_provider_test.py +++ b/src/fuzzed_data_provider_test.py @@ -24,7 +24,7 @@ if sys.version_info[0] >= 3: codepoint = chr - def to_bytes(n, length): + def to_bytes(n: int, length: int) -> bytes: return n.to_bytes(length, "little") else: @@ -32,7 +32,7 @@ def to_bytes(n, length): codepoint = unichr # noqa: F821 # functionality from python3's int.to_bytes() - def to_bytes(n, length): + def to_bytes(n: int, length: int) -> str: h = "%x" % n s = ("0" * (len(h) % 2) + h).zfill(length * 2).decode("hex") return s[::-1] diff --git a/src/instrument_bytecode.py b/src/instrument_bytecode.py index b252588a..a86c72a8 100644 --- a/src/instrument_bytecode.py +++ b/src/instrument_bytecode.py @@ -17,6 +17,7 @@ Mainly the function patch_code(), which can instrument a code object and the helper class Instrumentor. """ +import ast import collections import dis import gc @@ -24,7 +25,7 @@ import logging import sys import types -from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union from . import utils from .native import _reserve_counter # type: ignore[attr-defined] @@ -95,7 +96,7 @@ def __init__( opcode: int, arg: int = 0, min_size: int = 0, - positions=None, + positions: Optional[List[ast.AST]] = None, ): self.lineno = lineno self.offset = offset @@ -327,12 +328,12 @@ def __init__(self, code: types.CodeType): self._build_cfg() self._check_state() - def _insert_instruction(self, to_insert, lineno, offset, opcode, arg=0): + def _insert_instruction(self, to_insert: List[Instruction], lineno: int, offset: int, opcode: int, arg: int = 0) -> int: to_insert.append(Instruction(lineno, offset, opcode, arg)) offset += to_insert[-1].get_size() return self._insert_instructions(to_insert, lineno, offset, caches(opcode)) - def _insert_instructions(self, to_insert, lineno, offset, tuples): + def _insert_instructions(self, to_insert: List[Instruction], lineno: int, offset: int, tuples: List[Sequence[int]]) -> int: for t in tuples: offset = self._insert_instruction(to_insert, lineno, offset, t[0], t[1]) return offset diff --git a/src/utils.py b/src/utils.py index a0bd0937..242337b6 100644 --- a/src/utils.py +++ b/src/utils.py @@ -15,6 +15,7 @@ import sys import os +from typing import IO def path() -> str: @@ -26,7 +27,7 @@ def path() -> str: class ProgressRenderer: """Displays an updating progress meter in the terminal.""" - def __init__(self, stream, total_count: int): + def __init__(self, stream: IO[str], total_count: int): assert stream.isatty() self.stream = stream diff --git a/src/version_dependent.py b/src/version_dependent.py index 21c66c4c..0d361fba 100644 --- a/src/version_dependent.py +++ b/src/version_dependent.py @@ -30,10 +30,10 @@ """ import sys -import types import dis import opcode -from typing import List +import types +from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Tuple, Union PYTHON_VERSION = sys.version_info[:2] @@ -165,7 +165,7 @@ def rel_reference_scale(opname: str) -> int: if (3, 6) <= PYTHON_VERSION <= (3, 7): def get_code_object( - code_obj, stacksize, bytecode, consts, names, lnotab, exceptiontable + code_obj: types.CodeType, stacksize: int, bytecode: bytes, consts: Tuple[str, ...], names: Tuple[str, ...], lnotab: bytes, exceptiontable: bytes ): return types.CodeType(code_obj.co_argcount, code_obj.co_kwonlyargcount, code_obj.co_nlocals, stacksize, code_obj.co_flags, @@ -177,7 +177,7 @@ def get_code_object( elif (3, 8) <= PYTHON_VERSION <= (3, 10): def get_code_object( - code_obj, stacksize, bytecode, consts, names, lnotab, exceptiontable + code_obj: types.CodeType, stacksize: int, bytecode: bytes, consts: Tuple[str, ...], names: Tuple[str, ...], lnotab: bytes, exceptiontable: bytes ): return types.CodeType(code_obj.co_argcount, code_obj.co_posonlyargcount, code_obj.co_kwonlyargcount, code_obj.co_nlocals, @@ -189,7 +189,7 @@ def get_code_object( else: def get_code_object( - code_obj, stacksize, bytecode, consts, names, lnotab, exceptiontable + code_obj: types.CodeType, stacksize: int, bytecode: bytes, consts: Tuple[str, ...], names: Tuple[str, ...], lnotab: bytes, exceptiontable: bytes ): return types.CodeType( code_obj.co_argcount, @@ -241,7 +241,7 @@ def add_bytes_to_jump_arg(arg: int, size: int) -> int: if (3, 6) <= PYTHON_VERSION <= (3, 9): - def get_lnotab(code, listing): + def get_lnotab(code: types.CodeType, listing: List) -> bytes: """Returns line number table.""" lnotab = [] current_lineno = listing[0].lineno @@ -296,7 +296,7 @@ def get_lnotab(code, listing): elif (3, 10) <= PYTHON_VERSION <= (3, 10): - def get_lnotab(code, listing): + def get_lnotab(code: types.CodeType, listing: List) -> bytes: """Returns line number table.""" lnotab = [] prev_lineno = listing[0].lineno @@ -330,7 +330,7 @@ def get_lnotab(code, listing): elif (3, 11) <= PYTHON_VERSION <= (3, 11): from .native import _generate_codetable # pytype: disable=import-error - def get_lnotab(code, listing): + def get_lnotab(code: types.CodeType, listing: List) -> bytes: ret = _generate_codetable(code, listing) return ret @@ -338,7 +338,7 @@ def get_lnotab(code, listing): class ExceptionTableEntry: - def __init__(self, start_offset, end_offset, target, depth, lasti): + def __init__(self, start_offset: int, end_offset: int, target: int, depth: int, lasti: bool): self.start_offset = start_offset self.end_offset = end_offset self.target = target @@ -353,9 +353,10 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return ( - self.start_offset == other.start_offset + isinstance(other, ExceptionTableEntry) + and self.start_offset == other.start_offset and self.end_offset == other.end_offset and self.target == other.target and self.depth == other.depth @@ -374,7 +375,9 @@ def __repr__(self) -> str: def __str__(self) -> str: return "\n".join([repr(x) for x in self.entries]) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, ExceptionTable): + return False if len(self.entries) != len(other.entries): return False for i in range(len(self.entries)): @@ -384,26 +387,26 @@ def __eq__(self, other): # Default implementations # 3.11+ override these. -def generate_exceptiontable(original_code, exception_table_entries): +def generate_exceptiontable(original_code: types.CodeType, exception_table_entries: List[ExceptionTableEntry]) -> bytes: return b"" -def parse_exceptiontable(code): +def parse_exceptiontable(code: Union[types.CodeType, bytes]) -> ExceptionTable: return ExceptionTable([]) if (3, 11) <= PYTHON_VERSION <= (3, 11): from .native import _generate_exceptiontable # pytype: disable=import-error - def generate_exceptiontable(original_code, exception_table_entries): # noqa: F811 + def generate_exceptiontable(original_code: types.CodeType, exception_table_entries: List[ExceptionTableEntry]) -> bytes: # noqa: F811 return _generate_exceptiontable(original_code, exception_table_entries) - def parse_exceptiontable(co_exceptiontable): # noqa: F811 + def parse_exceptiontable(co_exceptiontable: Union[types.CodeType, bytes]) -> ExceptionTable: # noqa: F811 if isinstance(co_exceptiontable, types.CodeType): return parse_exceptiontable(co_exceptiontable.co_exceptiontable) # These functions taken from: # https://github.com/python/cpython/blob/main/Objects/exception_handling_notes.txt - def parse_varint(iterator): + def parse_varint(iterator: Iterator[int]) -> int: b = next(iterator) val = b & 63 while b & 64: @@ -412,7 +415,7 @@ def parse_varint(iterator): val |= b & 63 return val - def parse_exception_table(co_exceptiontable): + def parse_exception_table(co_exceptiontable: bytes) -> Iterator[tuple[int, int, int, int, bool]]: iterator = iter(co_exceptiontable) try: while True: @@ -441,15 +444,15 @@ def parse_exception_table(co_exceptiontable): if (3, 6) <= PYTHON_VERSION <= (3, 10): # There are no CACHE instructions in these versions, so return 0. - def cache_count(op): + def cache_count(op: Union[str, int]) -> int: return 0 # There are no CACHE instructions in these versions, so return empty list. - def caches(op): + def caches(op: int) -> List[Tuple[int, int]]: return [] # Rotate the top width_n instructions, shift_n times. - def rot_n(width_n: int, shift_n: int = 1): + def rot_n(width_n: int, shift_n: int = 1) -> List[Sequence[int]]: if shift_n != 1: return RuntimeError("rot_n not supported with shift_n!=1. (Support could be emulated if needed.)") @@ -478,11 +481,11 @@ def rot_n(width_n: int, shift_n: int = 1): return [(dis.opmap["ROT_N"], width_n)] # 3.11+ needs a null terminator for the argument list, but 3.10- does not. - def args_terminator(): + def args_terminator() -> List[Tuple[int, int]]: return [] # In 3.10-, all you need to call a function is CALL_FUNCTION. - def call(argc: int): + def call(argc: int) -> List[Tuple[int, int]]: return [(dis.opmap["CALL_FUNCTION"], argc)] # In 3.10-, each call pops 1 thing other than the arguments off the stack: @@ -493,19 +496,19 @@ def call(argc: int): if PYTHON_VERSION >= (3, 11): # The number of CACHE instructions that must go after the given instr. - def cache_count(op): + def cache_count(op: str | int) -> int: if isinstance(op, str): op = dis.opmap[op] return getattr(opcode, '_inline_cache_entries')[op] # Generate a list of CACHE instructions for the given instr. - def caches(op): + def caches(op: int) -> List[Tuple[int, int]]: cc = cache_count(op) return [(dis.opmap["CACHE"], 0)] * cc # Rotate the top width_n instructions, shift_n times. - def rot_n(width_n: int, shift_n: int = 1): + def rot_n(width_n: int, shift_n: int = 1) -> List[Sequence[int]]: ret = [] for j in range(shift_n): for i in range(width_n, 1, -1): @@ -514,11 +517,11 @@ def rot_n(width_n: int, shift_n: int = 1): # Calling a free function in 3.11 requires a null terminator for the # args list on the stack. - def args_terminator(): + def args_terminator() -> List[Tuple[int, int]]: return [(dis.opmap["PUSH_NULL"], 0)] # 3.11 requires a PRECALL instruction prior to every CALL instruction. - def call(argc: int): + def call(argc: int) -> List[Tuple[int, int]]: ret = [] ret.append((dis.opmap["PRECALL"], argc)) ret.append((dis.opmap["CALL"], argc)) @@ -533,13 +536,13 @@ def call(argc: int): if (3, 6) <= PYTHON_VERSION <= (3, 10): - def get_instructions(x, *, first_line=None): + def get_instructions(x: types.CodeType, *, first_line: Optional[int] = None) -> Iterator[dis.Instruction]: return dis.get_instructions(x, first_line=first_line) if (3, 11) <= PYTHON_VERSION: - def get_instructions(x, *, first_line=None, adaptive=False): + def get_instructions(x: types.CodeType, *, first_line: Optional[int] = None, adaptive: bool = False) -> Iterator[dis.Instruction]: return dis.get_instructions( x, first_line=first_line, adaptive=adaptive, show_caches=True )