diff --git a/hyperlight/convert.py b/hyperlight/convert.py index 1625844..a4609e1 100644 --- a/hyperlight/convert.py +++ b/hyperlight/convert.py @@ -85,8 +85,8 @@ def hypernetize_single( def hypernetize( model: nn.Module, - modules: Optional[List[nn.Module]] = None, - parameters: Optional[List[nn.Parameter]] = None, + modules: Optional[Union[List[nn.Module], Dict[str, nn.Module]]] = None, + parameters: Optional[Union[List[nn.Parameter], Dict[str, nn.Parameter]]] = None, return_values: bool = False, inplace: bool = True, ):