Skip to content

Commit 0286b49

Browse files
fix onnx_compatible_mode (#101)
1 parent 72fdee9 commit 0286b49

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

moge/model/v2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class MoGeModel(nn.Module):
2525
points_head: ConvStack
2626
mask_head: ConvStack
2727
scale_head: MLP
28+
onnx_compatible_mode: bool
2829

2930
def __init__(self,
3031
encoder: Dict[str, Any],
@@ -63,6 +64,15 @@ def device(self) -> torch.device:
6364
def dtype(self) -> torch.dtype:
6465
return next(self.parameters()).dtype
6566

67+
@property
68+
def onnx_compatible_mode(self) -> bool:
69+
return getattr(self, "_onnx_compatible_mode", False)
70+
71+
@onnx_compatible_mode.setter
72+
def onnx_compatible_mode(self, value: bool):
73+
self._onnx_compatible_mode = value
74+
self.encoder.onnx_compatible_mode = value
75+
6676
@classmethod
6777
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
6878
"""

0 commit comments

Comments
 (0)