-
Notifications
You must be signed in to change notification settings - Fork 55
Description
Each parametrization has its own parameters. The parameters are eventually passed to the optimizer. The base parametrization class defines what counts as parameters. This is super ugly and depends on the class of self. A more elegant and less error-prone alternative is to define the parameters as the union of the estimators (that define the parametrization)'s parameters. Each estimator contains either a module or a tensor.
An idea would be to make the logZ estimator's tensor a module (that inherits from GFNModule), and make an abstract property (called gfn_parameters) that each module needs to implement. Those of Uniform would be empty for example. The nn.Modules can call self.named_parameters() to implement gfn_parameters. Inheritance there should be revisited. Tabular shouldn't need to inherit from nn.Module.