I have a question about the semi_supervised_evaluation.py script in FedSSL. #26

@ss3b3

```python
# fine-tune model
if args.use_MLP:
    logreg = MLP(num_features, n_classes, 4096)
    logreg = logreg.to(device)
else:
    logreg = nn.Sequential(nn.Linear(num_features, n_classes))
    logreg = logreg.to(device)

# loss / optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)

# Train fine-tuned model
logreg.train()
resnet.train()
accs = []
for epoch in range(args.num_epochs):
    print("======epoch {}======".format(epoch))
    metrics = defaultdict(list)
    for step, (h, y) in enumerate(train_loader):
        h = h.to(device)
        y = y.to(device)

        outputs = logreg(resnet(h))

        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # calculate accuracy and save metrics
        accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
        metrics["Loss/train"].append(loss.item())
        metrics["Accuracy/train"].append(accuracy)

    print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join(
        [f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))

    if epoch % 1 == 0:
        acc = test_whole(resnet, logreg, device, test_loader, args.model_path)
        if epoch <= 100:
            accs.append(acc)
test_whole(resnet, logreg, device, test_loader, args.model_path)
print(args.model_path)
print(f"Best one for 100 epoch is {max(accs):.4f}")
```

I have a question about the `semi_supervised_evaluation.py` script in FedSSL. During fine-tuning, the Adam optimizer is created with only `logreg.parameters()`. Does this mean that the encoder (`resnet`) parameters are not updated during fine-tuning, and that only the `logreg` model is trained? Is this correct?
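To make the question concrete, here is a minimal sketch of the same pattern. The `encoder` and `head` modules are small hypothetical stand-ins for `resnet` and `logreg`, not the FedSSL models, and the data is random. As far as I understand standard PyTorch behavior, an optimizer only updates the parameters passed to it, so the encoder weights should stay fixed, while BatchNorm running statistics can still change because the encoder is in `train()` mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny encoder with BatchNorm in place of the ResNet,
# and a linear head in place of logreg.
encoder = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))
head = nn.Linear(4, 3)

# As in the script, only the head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=head.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

encoder.train()
head.train()

# Snapshot the state before one training step.
enc_weight_before = encoder[0].weight.detach().clone()
head_weight_before = head.weight.detach().clone()
bn_mean_before = encoder[1].running_mean.clone()

# One step on random data (illustrative only).
x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

loss = criterion(head(encoder(x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Compare the state after the step with the snapshots.
encoder_weights_unchanged = torch.equal(enc_weight_before, encoder[0].weight)
head_weights_changed = not torch.equal(head_weight_before, head.weight)
encoder_has_grad = encoder[0].weight.grad is not None
bn_stats_changed = not torch.equal(bn_mean_before, encoder[1].running_mean)

print("encoder weights unchanged:", encoder_weights_unchanged)
print("head weights changed:", head_weights_changed)
print("encoder gradients computed:", encoder_has_grad)
print("BatchNorm running stats changed:", bn_stats_changed)
```

In this sketch gradients are still computed for the encoder (it is not wrapped in `torch.no_grad()`), but no optimizer step is applied to them. If the same holds for the real script, the encoder weights would be frozen in practice while its BatchNorm statistics keep adapting to the fine-tuning data.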
