Description
When I try to use an Electra checkpoint for prediction, the code recently merged in PR #130 causes the following error:
```
  File "chebai/result/prediction.py", line 171, in <module>
    CLI(MainPredictor, as_positional=False)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/jsonargparse/_cli.py", line 23, in CLI
    return auto_cli(*args, _stacklevel=3, **kwargs)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/jsonargparse/_cli.py", line 102, in auto_cli
    return _run_component(components, init)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/jsonargparse/_cli.py", line 210, in _run_component
    return component(**cfg)
  File "chebai/result/prediction.py", line 152, in predict_from_file
    predictor.predict_from_file(
  File "chebai/result/prediction.py", line 89, in predict_from_file
    preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "chebai/result/prediction.py", line 128, in predict_smiles
    self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
  File "/home/staff/a/akhedekar/python-chebai/chebai/models/base.py", line 247, in predict_step
    pr, _ = self._get_prediction_and_labels(data, labels, model_output)
  File "/home/staff/a/akhedekar/python-chebai/chebai/models/electra.py", line 330, in _get_prediction_and_labels
    d = d * (~missing_labels).int().to(
RuntimeError: The size of tensor a (1528) must match the size of tensor b (2) at non-singleton dimension 1
```

Additional Information
The shape of `missing_labels` is `[2, 2]`, with the following values: `[[False, False], [False, False]]`.
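The error itself is a broadcasting failure: PyTorch aligns the two operands from their trailing dimension, and each pair of sizes must be equal or one of them must be 1. The message only confirms that `d` has 1528 entries at dimension 1; assuming `d` is `[2, 1528]` (a batch of two, 1528 output classes), it cannot be multiplied by a `[2, 2]` mask. The sketch below is a plain-Python reimplementation of that shape rule; the helper `broadcast_shape` is hypothetical and not part of chebai or PyTorch, and its error text only mimics PyTorch's.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b (NumPy/PyTorch rules).

    Shapes are aligned from the trailing dimension; two sizes are compatible
    if they are equal or one of them is 1. Raises ValueError otherwise.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            # Same situation as the RuntimeError in the traceback above.
            raise ValueError(
                f"size of tensor a ({x}) must match size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(result)


# Assumed shapes: predictions d are [2, 1528], the reported mask is [2, 2].
try:
    broadcast_shape((2, 1528), (2, 2))
except ValueError as err:
    print(err)
```

A mask of shape `[2, 1528]` (or `[1528]`) would broadcast cleanly against the assumed `[2, 1528]` predictions.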
The missing labels are added by the collator using the new logic from #130. The relevant part of the code is below:
```python
missing_labels = [
    d.get("missing_labels", [False for _ in y[0]]) for d in data
]
```
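The reported all-`False` `[2, 2]` mask follows from the default in that comprehension: any sample without a `missing_labels` key gets a row whose width is `len(y[0])`. That width only matches the model's 1528 outputs if `y[0]` is a full label vector, which the reported values suggest is not the case during prediction. The sketch below illustrates this under an assumption made purely for illustration, namely that `y[0]` has two entries; `default_missing_labels_sized` and its `num_labels` argument are hypothetical and only show one possible way to tie the mask width to the model's output size.

```python
def default_missing_labels(data, y):
    # Mirrors the collator snippet: samples without a "missing_labels" key
    # get a row of False whose width is taken from y[0].
    return [d.get("missing_labels", [False for _ in y[0]]) for d in data]


def default_missing_labels_sized(data, num_labels):
    # Hypothetical variant: the default row width comes from an explicit
    # num_labels (e.g. the model's output size) instead of from y[0].
    return [d.get("missing_labels", [False] * num_labels) for d in data]


data = [{}, {}]  # two samples, neither provides "missing_labels"
y = [[0, 1], [1, 0]]  # hypothetical: y[0] has only two entries
print(default_missing_labels(data, y))
print(len(default_missing_labels_sized(data, 1528)[0]))
```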