Missing labels in loss kwargs causing issue for electra model #140

@aditya0by0

Description

When I try to use an electra checkpoint for prediction, the code recently merged in PR #130 raises the following error:

  File "chebai/result/prediction.py", line 171, in <module>
    CLI(MainPredictor, as_positional=False)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/jsonargparse/_cli.py", line 23, in CLI
    return auto_cli(*args, _stacklevel=3, **kwargs)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/jsonargparse/_cli.py", line 102, in auto_cli
    return _run_component(components, init)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/jsonargparse/_cli.py", line 210, in _run_component
    return component(**cfg)
  File "chebai/result/prediction.py", line 152, in predict_from_file
    predictor.predict_from_file(
  File "chebai/result/prediction.py", line 89, in predict_from_file
    preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)
  File "/home/staff/a/akhedekar/miniconda3/envs/gnn/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "chebai/result/prediction.py", line 128, in predict_smiles
    self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
  File "/home/staff/a/akhedekar/python-chebai/chebai/models/base.py", line 247, in predict_step
    pr, _ = self._get_prediction_and_labels(data, labels, model_output)
  File "/home/staff/a/akhedekar/python-chebai/chebai/models/electra.py", line 330, in _get_prediction_and_labels
    d = d * (~missing_labels).int().to(
RuntimeError: The size of tensor a (1528) must match the size of tensor b (2) at non-singleton dimension 1

Additional Information

The shape of missing_labels is [2, 2], with the following values: [[False, False], [False, False]]

The missing labels are added by the collator using the new logic from #130. The relevant part of the code is below:

            missing_labels = [
                d.get("missing_labels", [False for _ in y[0]]) for d in data
            ]
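
To illustrate why this default can clash with the model output, here is a minimal, dependency-free sketch. It is not the actual chebai collator: `build_missing_labels` and `broadcastable` are hypothetical helpers, and the 2-element placeholder labels are an assumption chosen to match the [2, 2] shape reported above. The point is that the default mask takes its width from y[0], not from the 1528 outputs of the model.

```python
# Hypothetical stand-in for the collator line quoted above; the surrounding
# code in the real collator may differ.
def build_missing_labels(data, y):
    # The default mask has one entry per label in the first sample of y.
    return [d.get("missing_labels", [False for _ in y[0]]) for d in data]


def broadcastable(shape_a, shape_b):
    # Trailing dimensions must be equal, or one of them must be 1.
    return all(
        p == q or p == 1 or q == 1
        for p, q in zip(reversed(shape_a), reversed(shape_b))
    )


# Assumption: at prediction time each sample carries a 2-element label.
data = [{"features": [1, 2, 3]}, {"features": [4, 5, 6]}]
y = [[0, 0], [0, 0]]

mask = build_missing_labels(data, y)
mask_shape = (len(mask), len(mask[0]))

n_classes = 1528  # width of "tensor a" in the traceback
output_shape = (len(data), n_classes)

print(mask_shape, output_shape, broadcastable(output_shape, mask_shape))
```

Under these assumptions the mask has shape (2, 2) while the output has shape (2, 1528), so the element-wise product in _get_prediction_and_labels cannot broadcast. One possible direction would be to skip the masking during prediction, or to size the default mask from the model output, but that is a design decision for the maintainers.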
