diff --git a/uq360/algorithms/variational_bayesian_neural_networks/bnn.py b/uq360/algorithms/variational_bayesian_neural_networks/bnn.py index 50b2914..f66d5aa 100644 --- a/uq360/algorithms/variational_bayesian_neural_networks/bnn.py +++ b/uq360/algorithms/variational_bayesian_neural_networks/bnn.py @@ -64,6 +64,9 @@ def fit(self, X, y): self """ + assert type(X) == torch.Tensor, f"Expected X type torch.Tensor but found {type(X)}" + assert type(y) == torch.Tensor, f"Expected y type torch.Tensor but found {type(y)}" + torch.manual_seed(1234) optimizer = torch.optim.Adam(self.net.parameters(), lr=self.config['step_size']) neg_elbo = torch.zeros([self.config['num_epochs'], 1])