
Consistency of regression layers (student-t and standard) #12

@brunzema


Hi,

I just found a subtle inconsistency in the regression layers, which I noticed while drawing posterior samples from the predictive.

In the standard (multivariate normal) regression layer, W is defined as a method:

    def W(self):
        cov_diag = torch.exp(self.W_logdiag)
        if self.W_dist == Normal:
            cov = self.W_dist(self.W_mean, cov_diag)
        elif self.W_dist == DenseNormal:
            tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
            cov = self.W_dist(self.W_mean, tril)
        elif self.W_dist == LowRankNormal:
            cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag)

        return cov

whereas for the t-VBLL regression layer, it is defined as a property:

    @property
    def W(self):
        cov_diag = torch.exp(self.W_logdiag)
        if self.W_dist == Normal:
            cov = self.W_dist(self.W_mean, cov_diag)
        elif self.W_dist == DenseNormal:
            tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
            cov = self.W_dist(self.W_mean, tril)
        elif self.W_dist == LowRankNormal:
            cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag)

        return cov

This, in turn, changes how you sample from W:

  • for VBLL: layer.W().rsample()
  • for tVBLL: layer.W.rsample()
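To make the difference concrete, here is a plain-Python toy (no torch) that mimics the two definitions. ToyNormal is a hypothetical stand-in for the Normal / DenseNormal / LowRankNormal objects, and MethodLayer / PropertyLayer are hypothetical names, not classes from the library. The point is only that a single call site such as layer.W.rsample() works for one layer and raises for the other.

```python
import random


class ToyNormal:
    """Hypothetical stand-in for a diagonal Normal over the weights."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def rsample(self):
        # One Gaussian draw per coordinate (no autograd involved here).
        return [random.gauss(m, s) for m, s in zip(self.mean, self.std)]


class MethodLayer:
    """Mimics the standard regression layer: W is a plain method."""

    def __init__(self):
        self.W_mean = [0.0, 1.0]
        self.W_std = [1.0, 0.5]

    def W(self):
        return ToyNormal(self.W_mean, self.W_std)


class PropertyLayer(MethodLayer):
    """Mimics the t-VBLL layer: W is a property."""

    @property
    def W(self):
        return ToyNormal(self.W_mean, self.W_std)


# Works: the property returns the distribution directly.
print(PropertyLayer().W.rsample())

# Fails: on MethodLayer, layer.W is a bound method, which has no rsample.
try:
    MethodLayer().W.rsample()
except AttributeError as err:
    print("AttributeError:", err)
```

Generic code that samples from either layer therefore has to know which layer type it is holding.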

I'd personally prefer W to be a property. I'm happy to open a PR for this, but wanted to double-check with you first.
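One possible shape for such a change, sketched in plain Python under the same toy assumptions as the issue's snippets: define W once as a property in a shared base and reuse it in every layer. WeightPosteriorMixin, Regression, TRegression, ToyNormal and posterior_W are all hypothetical names for illustration, not the library's actual classes or API.

```python
import random


class ToyNormal:
    """Hypothetical stand-in for the weight distribution object."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def rsample(self):
        return [random.gauss(m, s) for m, s in zip(self.mean, self.std)]


class WeightPosteriorMixin:
    """Hypothetical mixin: the W posterior is defined once, as a property."""

    @property
    def W(self):
        # Rebuilt on every access, so it always reflects the current parameters.
        return ToyNormal(self.W_mean, self.W_std)


class Regression(WeightPosteriorMixin):
    def __init__(self):
        self.W_mean = [0.0, 1.0]
        self.W_std = [1.0, 0.5]


class TRegression(WeightPosteriorMixin):
    def __init__(self):
        self.W_mean = [2.0]
        self.W_std = [0.1]


def posterior_W(layer):
    """Transition helper: returns the distribution whether W is a method or a property."""
    W = layer.W
    return W() if callable(W) else W


# Same call site for both layer types.
for layer in (Regression(), TRegression()):
    print(layer.W.rsample())
```

Switching from a method to a property is a breaking change for existing layer.W() calls, so a small helper like posterior_W could smooth the transition for downstream code.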


EDIT: Just checked; the same holds for the classification case.
