diff --git a/pytools/py_codegen.py b/pytools/py_codegen.py index a3f6e608..01943d89 100644 --- a/pytools/py_codegen.py +++ b/pytools/py_codegen.py @@ -52,10 +52,14 @@ def get_picklable_module(self, name=None): class PythonFunctionGenerator(PythonCodeGenerator): - def __init__(self, name, args): + def __init__(self, name, args, decorators=None): PythonCodeGenerator.__init__(self) self.name = name + if decorators: + for decorator in decorators: + self(decorator) + self("def {}({}):".format(name, ", ".join(args))) self.indent() diff --git a/pytools/test/test_py_codegen.py b/pytools/test/test_py_codegen.py index d2f0401a..5f7b280a 100644 --- a/pytools/test/test_py_codegen.py +++ b/pytools/test/test_py_codegen.py @@ -2,6 +2,8 @@ import sys +import pytest + import pytools import pytools.py_codegen as codegen @@ -30,6 +32,41 @@ def test_picklable_function(): assert f() == 1 +def test_function_decorators(capfd): + cg = codegen.PythonFunctionGenerator("f", args=(), decorators=["@staticmethod"]) + cg("return 42") + + assert cg.get_function()() == 42 + + cg = codegen.PythonFunctionGenerator("f", args=(), decorators=["@classmethod"]) + cg("return 42") + + with pytest.raises(TypeError): + cg.get_function()() + + cg = codegen.PythonFunctionGenerator("f", args=(), + decorators=["@staticmethod", "@classmethod"]) + cg("return 42") + + with pytest.raises(TypeError): + cg.get_function()() + + cg = codegen.PythonFunctionGenerator("f", args=("x"), + decorators=["from functools import lru_cache", "@lru_cache"]) + cg("print('Hello World!')") + cg("return 42") + + f = cg.get_function() + + assert f(0) == 42 + out, _err = capfd.readouterr() + assert out == "Hello World!\n" + + assert f(0) == 42 + out, _err = capfd.readouterr() + assert out == "" # second print is not executed due to lru_cache + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])