```python
# Imports used by this snippet (resnet, MLP, test_whole, args, device,
# train_loader, and test_loader come from the surrounding script).
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn

# fine-tune model: either an MLP head or a single linear layer
if args.use_MLP:
    logreg = MLP(num_features, n_classes, 4096)
else:
    logreg = nn.Sequential(nn.Linear(num_features, n_classes))
logreg = logreg.to(device)

# loss / optimizer (note: only the head's parameters are passed to Adam)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)

# train fine-tuned model
logreg.train()
resnet.train()
accs = []
for epoch in range(args.num_epochs):
    print("======epoch {}======".format(epoch))
    metrics = defaultdict(list)
    for step, (h, y) in enumerate(train_loader):
        h = h.to(device)
        y = y.to(device)
        outputs = logreg(resnet(h))
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # calculate accuracy and save metrics
        accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
        metrics["Loss/train"].append(loss.item())
        metrics["Accuracy/train"].append(accuracy)
    print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join(
        [f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))
    if epoch % 1 == 0:  # evaluates every epoch (epoch % 1 is always 0)
        acc = test_whole(resnet, logreg, device, test_loader, args.model_path)
        if epoch <= 100:
            accs.append(acc)

test_whole(resnet, logreg, device, test_loader, args.model_path)
print(args.model_path)
print(f"Best one for 100 epochs is {max(accs):.4f}")
```
I have a question about the `semi_supervised_evaluation.py` script in FedSSL. During fine-tuning, the Adam optimizer is constructed with only `logreg.parameters()`. Does this mean the encoder's parameters are never updated during fine-tuning, and only the `logreg` head is trained? Is that correct?
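
If it helps make the question concrete, here is a minimal sketch (my own, not from the repository; `resnet`, `logreg`, `args`, and the training loop refer to the snippet above) that checks whether the encoder's weights actually change, plus what passing both parameter sets to the optimizer would look like:

```python
import itertools

import torch

# Snapshot every encoder weight before fine-tuning.
before = {n: p.detach().clone() for n, p in resnet.named_parameters()}

# ... run one or more fine-tuning steps exactly as in the snippet above ...

# Any parameter the optimizer touched will now differ from its snapshot.
changed = [n for n, p in resnet.named_parameters()
           if not torch.equal(before[n], p.detach())]
print("encoder parameters that changed:", changed)
# With Adam over logreg.parameters() only, this list should be empty.

# Hypothetical alternative: jointly fine-tuning encoder and head would
# require handing both parameter sets to the optimizer, e.g.:
optimizer = torch.optim.Adam(
    itertools.chain(resnet.parameters(), logreg.parameters()),
    lr=args.learning_rate,
)
```

One related detail: even with the encoder excluded from the optimizer, the call to `resnet.train()` keeps its BatchNorm layers in training mode, so their running mean/variance buffers (which are not parameters) still update on every forward pass.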