diff --git a/src/litgen/internal/adapted_types/adapted_function.py b/src/litgen/internal/adapted_types/adapted_function.py index be137d4d..dc6942b9 100644 --- a/src/litgen/internal/adapted_types/adapted_function.py +++ b/src/litgen/internal/adapted_types/adapted_function.py @@ -549,6 +549,19 @@ def _pydef_end_arg_docstring_returnpolicy(self) -> str: replace_lines.maybe_keep_alive = self._pydef_fill_keep_alive_from_function_comment() replace_lines.maybe_call_guard = self._pydef_fill_call_guard_from_function_comment() + # Add gil_scoped_release call guard if regex matches + if code_utils.does_match_regex_or_matcher( + self.options.fn_add_gil_scoped_release_guard__regex, self.cpp_element().function_name + ): + py_ns = "py" if self.options.bind_library == BindLibraryType.pybind11 else "nb" + gil_guard = f"{py_ns}::call_guard<{py_ns}::gil_scoped_release>()" + if replace_lines.maybe_call_guard is None: + replace_lines.maybe_call_guard = gil_guard + else: + # If a call guard is already present from comments, append the new one. + # Note: pybind11 supports multiple call guards. + replace_lines.maybe_call_guard += f", {gil_guard}" + # Process template code = code_utils.process_code_template( input_string=template_code, diff --git a/src/litgen/options.py b/src/litgen/options.py index e241a6c2..0420e0fe 100644 --- a/src/litgen/options.py +++ b/src/litgen/options.py @@ -243,6 +243,13 @@ class LitgenOptions: # error: no matching function for call to object of type 'const detail::overload_cast_impl<...>' fn_force_lambda__regex: RegexOrMatcher = "" + # ------------------------------------------------------------------------------ + # Add py::call_guard() + # ------------------------------------------------------------------------------ + # Add a GIL release call guard for functions that matches these regexes. + # This is useful for long-running C++ functions that don't need to interact with Python objects. + fn_add_gil_scoped_release_guard__regex: RegexOrMatcher = "" + # ------------------------------------------------------------------------------ # C style buffers to py::array # ------------------------------------------------------------------------------ diff --git a/src/litgen/tests/internal/adapted_types/adapted_function_test.py b/src/litgen/tests/internal/adapted_types/adapted_function_test.py index d3887635..1fc76920 100644 --- a/src/litgen/tests/internal/adapted_types/adapted_function_test.py +++ b/src/litgen/tests/internal/adapted_types/adapted_function_test.py @@ -127,6 +127,38 @@ def test_return_policy_regex() -> None: """, ) +def test_gil_scoped_release_py() -> None: + options = LitgenOptions() + options.bind_library = litgen.BindLibraryType.pybind11 + options.fn_add_gil_scoped_release_guard__regex = r"^Foo" + + code = """ + void Foo(); + """ + generated_code = LitgenGeneratorTestsHelper.code_to_pydef(options, code) + expected_code = """ + m.def("foo", + Foo, py::call_guard()); + """ + # logging.warning("\n" + generated_code) + code_utils.assert_are_codes_equal(generated_code, expected_code) + +def test_gil_scoped_release_nb() -> None: + options = LitgenOptions() + options.bind_library = litgen.BindLibraryType.nanobind + options.fn_add_gil_scoped_release_guard__regex = r"^Foo" + + code = """ + void Foo(); + """ + generated_code = LitgenGeneratorTestsHelper.code_to_pydef(options, code) + expected_code = """ + m.def("foo", + Foo, nb::call_guard()); + """ + code_utils.assert_are_codes_equal(generated_code, expected_code) + + def test_implot_one_buffer() -> None: options = LitgenOptions()