
Fine-tuning Presto on the dataset used in the downstream task demo notebook #39

@hsotoparada


Hi @gabrieltseng, I've read your paper and found it really interesting work!
Thanks a lot for sharing your code as well!

I'm trying to adapt your downstream task notebook to fine-tune the pretrained Presto model on the same dataset used in the notebook.

My approach is based on the README instructions, the code in the notebook, and the evaluate and finetune functions in cropharvest_eval.py. The main part of my code is:

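## using Presto for finetuning
# based on the functions evaluate and finetune in presto/eval/cropharvest_eval.py
# (presto, device, train_dl and test_dl are assumed to be set up earlier, as in the downstream task notebook)

import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

pretrained_model = presto.Presto.load_pretrained()
print(type(pretrained_model))
pretrained_model.eval()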

# build finetuning model: encoder + linear transformation (FinetuningHead)

num_outputs = 2
regression = False

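finetuning_model = pretrained_model.construct_finetuning_model(
    num_outputs=num_outputs,
    regression=regression,
)
# keep the model on the same device as the batches fed to it
finetuning_model.to(device)

opt = Adam(finetuning_model.parameters(), lr=0.01)
# CrossEntropyLoss applies log-softmax internally, so it takes raw (unnormalized) scores
loss_fn = nn.CrossEntropyLoss(reduction="mean")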

# train finetuning model

max_epochs = 5

train_loss = []
for epoch in range(max_epochs):
    print(f"Training for epoch: {(epoch+1):03}")
    finetuning_model.train()
    epoch_train_loss = 0.0

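    for (x, mask, dw, latlons, y, month) in tqdm(train_dl):

        # move the batch to the model's device, using the same dtypes as in the test loop
        x = x.to(device)
        mask = mask.to(device)
        dw = dw.to(device).long()
        latlons = latlons.to(device).float()
        month = month.to(device)
        # CrossEntropyLoss expects integer class indices on the same device as preds
        # (y.type(torch.LongTensor) would always produce a CPU tensor)
        y = y.to(device).long()

        # zero the parameter gradients
        opt.zero_grad()

        # forward + backward + optimize
        preds = finetuning_model(
            x,
            dynamic_world=dw,
            mask=mask,
            latlons=latlons,
            month=month,
        )
        loss = loss_fn(preds, y)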
        epoch_train_loss += loss.item()
        loss.backward()
        opt.step()

    train_loss.append(epoch_train_loss / len(train_dl))

# make predictions using finetuning model

# switch to evaluation mode once, before iterating over the test batches
finetuning_model.eval()
test_preds = []
for (x, mask, dw, latlons, month) in tqdm(test_dl):

    x = x.to(device)
    dw = dw.to(device).long()
    mask = mask.to(device)
    latlons = latlons.to(device).float()
    month = month.to(device)

    with torch.no_grad():
        preds = (
            finetuning_model(
                x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
            )
            .cpu()
            .numpy()
        )
        # preds = np.argmax(preds, axis=-1)
        test_preds.append(preds)

print("predicting with finetuning model...")
print(len(test_preds))
print(test_preds[0])
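The commented-out np.argmax line points at the next step. Below is a minimal sketch, assuming each element of test_preds is a NumPy array of shape (batch_size, 2) with one row of scores per sample, that stacks the per-batch arrays and reads off one predicted class per sample. collect_predictions is a hypothetical helper written for illustration, not part of Presto, and the dummy values are illustrative only.

```python
import numpy as np


def collect_predictions(batch_preds):
    """Stack per-batch score arrays and return (scores, labels).

    batch_preds: list of arrays, each of shape (batch_size, num_classes).
    (Hypothetical helper, not part of the Presto codebase.)
    """
    # concatenate along the sample axis -> (num_samples, num_classes)
    scores = np.concatenate(batch_preds, axis=0)
    # argmax over the class axis gives the predicted class index per sample
    labels = np.argmax(scores, axis=-1)
    return scores, labels


# usage with two small dummy batches (illustrative values, not real predictions)
batches = [
    np.array([[-9.5, 9.8], [1.0, -2.0]]),
    np.array([[0.3, 0.1]]),
]
scores, labels = collect_predictions(batches)
print(scores.shape)  # (3, 2)
print(labels)        # [1 0 0]
```

Concatenating once at the end keeps the per-sample order of test_dl, assuming the loader is not shuffled.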

From the printed output, I see, for example, that the predictions in test_preds[0] are:

predicting with finetuning model...
53
[[-9.505264 , 9.816492 ],
[-9.501129 , 9.811971 ],
[-9.496433 , 9.806909 ],
[-9.49617 , 9.806579 ],
[-9.495665 , 9.805991 ],
[-9.4937105, 9.803866 ],
[-9.497982 , 9.808611 ],
[-9.507018 , 9.818317 ],
[-9.520019 , 9.832625 ],
[-9.512251 , 9.824137 ],
...
[-9.4941025, 9.804452 ],
[-9.506046 , 9.817224 ],
[-9.48634 , 9.795685 ],
[-9.496958 , 9.807267 ]]

I get similar numbers for the remaining elements of test_preds.
But if these numbers are predictions, I would expect them to be probabilities that sum to 1. Or should that not be the case here?
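For context: nn.CrossEntropyLoss expects unnormalized scores (logits) and applies log-softmax internally, so a classifier trained with it is usually left without a final softmax, and its raw outputs need not sum to 1. The sketch below takes two rows copied from the printed output above, assumes they are such raw logits from the finetuning head, converts them to probabilities with torch.softmax, and checks that cross_entropy equals log_softmax followed by nll_loss. The target labels are illustrative, not taken from the dataset.

```python
import torch
import torch.nn.functional as F

# two rows copied from test_preds[0] above, assumed to be raw logits
logits = torch.tensor([[-9.505264, 9.816492],
                       [-9.501129, 9.811971]])

# softmax over the class dimension turns each row into probabilities summing to 1
probs = torch.softmax(logits, dim=-1)
print(probs.sum(dim=-1))     # each entry is (numerically) 1
print(probs.argmax(dim=-1))  # tensor([1, 1])

# cross_entropy on raw logits == log_softmax followed by negative log-likelihood
targets = torch.tensor([1, 0])  # illustrative labels only
ce = F.cross_entropy(logits, targets)
nll = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
print(torch.allclose(ce, nll))  # True
```

Because softmax preserves the ordering within each row, the argmax of the raw logits gives the same class as the argmax of the probabilities; softmax is only needed when calibrated-looking scores are wanted.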

I guess I may be missing a step, but I can't figure out what it is.
Could you give me a hint? I would really appreciate your help.

Cheers,
Hugo
