-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathgen_interface.py
More file actions
120 lines (100 loc) · 4.57 KB
/
gen_interface.py
File metadata and controls
120 lines (100 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import inspect
import textwrap
import re
import itertools
import numbers
import importlib
import sys
import functools
from pathlib import Path
import shutil
from utils3d.helpers import suppress_traceback
def _contains_tensor(obj):
    """Return True if ``obj`` is, or recursively contains, a ``torch.Tensor``.

    Lists and tuples are searched element-wise and dicts by value (keys are
    not inspected); any other object is tested with ``isinstance``.

    torch is looked up in ``sys.modules`` rather than imported: if torch has
    not been imported yet, no tensor can exist, so the answer is False without
    paying for (or failing on) a torch import at every leaf.
    """
    if isinstance(obj, (list, tuple)):
        return any(_contains_tensor(item) for item in obj)
    if isinstance(obj, dict):
        return any(_contains_tensor(value) for value in obj.values())
    torch = sys.modules.get('torch')
    return torch is not None and isinstance(obj, torch.Tensor)
@suppress_traceback
def _call_based_on_args(fname, args, kwargs):
    """Dispatch ``fname`` to the utils3d torch or numpy backend and call it.

    The torch backend is chosen only when torch is already present in
    ``sys.modules`` and some positional or keyword argument (searched through
    nested lists, tuples and dict values) is a ``torch.Tensor``; every other
    call goes to the numpy backend.

    Raises:
        NotImplementedError: if the chosen backend has no attribute ``fname``.
    """
    # NOTE(review): relies on ``utils3d.torch`` / ``utils3d.numpy`` being
    # reachable as attributes of the ``utils3d`` package object — confirm the
    # package imports its submodules.
    torch_call = 'torch' in sys.modules and (
        _contains_tensor(args) or _contains_tensor(kwargs)
    )
    if torch_call:
        impl = getattr(utils3d.torch, fname, None)
        if impl is not None:
            return impl(*args, **kwargs)
        raise NotImplementedError(f"Function {fname} has no torch implementation.")
    impl = getattr(utils3d.numpy, fname, None)
    if impl is not None:
        return impl(*args, **kwargs)
    raise NotImplementedError(f"Function {fname} has no numpy implementation.")
def get_signature(fn):
    """Return ``fn``'s signature as a string suitable for a ``.pyi`` stub.

    Two rewrites are applied to ``str(inspect.signature(fn))``:

    * every class repr ``<class 'pkg.Name'>`` (e.g. from a class-valued
      default) is collapsed to ``pkg.Name``; the match stops at the first
      closing quote so several reprs in one signature are handled one by one
      instead of being merged into a single greedy match that drops the
      parameters between them;
    * a ``numpy.`` / ``torch.`` prefix that is not part of a longer dotted
      path or identifier is renamed to ``numpy_.`` / ``torch_.`` to match the
      stub's ``import numpy as numpy_`` / ``import torch as torch_``.
    """
    signature_str = str(inspect.signature(fn))
    signature_str = re.sub(r"<class '([^']*)'>", r"\1", signature_str)
    # Exclude a preceding '.' (e.g. utils3d.numpy.) and word characters
    # (e.g. pynumpy.) so only a bare top-level module prefix is rewritten.
    signature_str = re.sub(r"(?<![\w.])numpy\.", "numpy_.", signature_str)
    signature_str = re.sub(r"(?<![\w.])torch\.", "torch_.", signature_str)
    return signature_str
def escape_docstring(doc):
    """Return ``doc`` escaped for embedding in a triple-double-quoted literal.

    Backslashes are doubled, every run of three double quotes is escaped and
    a trailing double quote is escaped, so the emitted literal parses back to
    exactly ``doc``. Without this, LaTeX-style backslash sequences would be
    reinterpreted as escapes (or raise invalid-escape warnings) and an
    embedded or trailing quote could terminate the literal early.
    """
    escaped = doc.replace('\\', '\\\\')
    # Handle the trailing quote first so it is never double-escaped by the
    # triple-quote replacement below.
    trailing_quote = escaped.endswith('"')
    if trailing_quote:
        escaped = escaped[:-1]
    escaped = escaped.replace('"""', '\\"\\"\\"')
    if trailing_quote:
        escaped += '\\"'
    return escaped


if __name__ == "__main__":
    import utils3d.numpy
    import utils3d.torch

    numpy_impl = utils3d.numpy
    torch_impl = utils3d.torch

    # Resolve every exported name; names listed in __all__ but missing from
    # the module are reported and then dropped.
    numpy_funcs = {name: getattr(numpy_impl, name, None) for name in numpy_impl.__all__}
    torch_funcs = {name: getattr(torch_impl, name, None) for name in torch_impl.__all__}
    for name, fn in numpy_funcs.items():
        if fn is None:
            print(f"\033[91mWarning: Function {name} is in numpy __all__ but not found in numpy module.\033[0m")
    for name, fn in torch_funcs.items():
        if fn is None:
            print(f"\033[91mWarning: Function {name} is in torch __all__ but not found in torch module.\033[0m")
    numpy_funcs = {name: fn for name, fn in numpy_funcs.items() if fn is not None}
    torch_funcs = {name: fn for name, fn in torch_funcs.items() if fn is not None}

    # Union of both backends (numpy names first); named all_funcs so the
    # builtin `all` is not shadowed.
    all_funcs = {**numpy_funcs, **torch_funcs}
    all_names = ', \n'.join('"' + s + '"' for s in all_funcs)

    interface_dir = Path("utils3d/interface")
    if interface_dir.exists():
        shutil.rmtree(interface_dir)
    interface_dir.mkdir(exist_ok=True)

    with open(interface_dir / "__init__.pyi", "w", encoding="utf-8") as f:
        f.write(inspect.cleandoc(
            """
            # Auto-generated interface file
            from typing import List, Tuple, Dict, Union, Optional, Any, overload, Literal, Callable
            from typing_extensions import Unpack
            import numpy as numpy_
            import torch as torch_
            import nvdiffrast.torch
            import numbers
            from . import numpy, torch
            import utils3d.numpy, utils3d.torch
            """
        ))
        f.write("\n\n")
        f.write(f"__all__ = [{all_names}]\n\n")
        for fname in all_funcs:
            # Type checkers require the overloads of one name to be adjacent
            # and to number at least two, so emit numpy then torch together and
            # use @overload only when both backends provide the function.
            impls = [fn for fn in (numpy_funcs.get(fname), torch_funcs.get(fname)) if fn is not None]
            for fn in impls:
                sig, doc = get_signature(fn), inspect.getdoc(fn)
                if len(impls) > 1:
                    f.write("@overload\n")
                f.write(f"def {fname}{sig}:\n")
                if doc:
                    f.write(f'    """{escape_docstring(doc)}"""\n')
                f.write(f"    {fn.__module__}.{fn.__qualname__}\n\n")

    with open(interface_dir / "__init__.py", "w", encoding="utf-8") as f:
        f.write(inspect.cleandoc(
            """
            # Auto-generated implementation redirecting to numpy/torch implementations
            import sys
            from typing import TYPE_CHECKING
            import utils3d
            from ..helpers import suppress_traceback
            """
        ))
        f.write("\n\n")
        f.write(f"__all__ = [{all_names}]\n\n")
        # The dispatch helpers are copied verbatim into the generated module.
        f.write(inspect.getsource(_contains_tensor) + "\n\n")
        f.write(inspect.getsource(_call_based_on_args) + "\n\n")
        for fname in all_funcs:
            numpy_ref = f"utils3d.numpy.{fname}" if fname in numpy_funcs else "None"
            torch_ref = f"utils3d.torch.{fname}" if fname in torch_funcs else "None"
            f.write("@suppress_traceback\n")
            f.write(f"def {fname}(*args, **kwargs):\n")
            # Never executed at runtime; lets IDEs jump to the real backends.
            f.write("    if TYPE_CHECKING:  # redirected to:\n")
            f.write(f"        {numpy_ref}, {torch_ref}\n")
            f.write(f"    return _call_based_on_args('{fname}', args, kwargs)\n\n")