
Argument in test_on_dataset and train_and_test_on_datasets functions to write "val_" metrics instead of "test_" #266

@arthur-thuy

Description


Is your feature request related to a problem? Please describe.
The MetricMixin class only creates "train_" and "test_" metrics in its add_metric method. This works fine when only a training set and a test set are used.

However, it becomes a problem when a validation set is also used, as in the two snippets below.

```python
# Snippet 1: evaluate on the validation set and on the test set separately.
for al_step in range(N_ALSTEP):
    _ = wrapper.train_on_dataset(
        train_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda
    )
    _ = wrapper.test_on_dataset(val_dataset, BATCH_SIZE)  # logged as "test_..."
    _ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # also "test_...", overwrites the line above
    metrics = wrapper.get_metrics()
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break
```

```python
# Snippet 2: train with val_dataset as the held-out set, then evaluate on test_dataset.
for al_step in range(N_ALSTEP):
    _ = wrapper.train_and_test_on_datasets(
        train_dataset, val_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda
    )  # validation results are logged as "test_..."
    _ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # overwrites them
    metrics = wrapper.get_metrics()
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break
```

In both snippets, the validation metrics are recorded under the "test_" prefix and are then overwritten by the actual test metrics, which are recorded under the same "test_" prefix.
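To make the overwrite concrete, here is a toy model of the behaviour described above. `ToyMetricStore` and `log_evaluation` are hypothetical names, not baal's `MetricMixin` implementation; the sketch only assumes, as the issue states, that each metric exists under a "train_" and a "test_" prefix and that every evaluation writes to the "test_" one. The logged values are dummy placeholders.

```python
from typing import Callable, Dict


class ToyMetricStore:
    """Hypothetical, simplified model of the behaviour described in this issue
    (not baal's code): add_metric registers one "train_" and one "test_" copy
    of each metric, and every evaluation writes to the "test_" copy."""

    def __init__(self) -> None:
        self.metrics: Dict[str, float] = {}

    def add_metric(self, name: str, initializer: Callable[[], float]) -> None:
        # Only two prefixes exist; there is no "val_" slot.
        self.metrics[f"train_{name}"] = initializer()
        self.metrics[f"test_{name}"] = initializer()

    def log_evaluation(self, name: str, value: float) -> None:
        # Any evaluation, on validation or test data, lands in the same key.
        self.metrics[f"test_{name}"] = value


store = ToyMetricStore()
store.add_metric("loss", float)     # float() initializes the metric to 0.0
store.log_evaluation("loss", 0.25)  # dummy validation-set value
store.log_evaluation("loss", 0.75)  # dummy test-set value
print(store.metrics)
```

Only the second (test-set) value survives, and there is no "val_" key that could have held the first.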

Describe the solution you'd like
It would be nice if the test_on_dataset and train_and_test_on_datasets functions had an argument to specify the prefix under which metrics are written ("val_" or "test_").
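One possible shape for the requested argument, sketched on a hypothetical stand-in class rather than baal's actual wrapper: the evaluation method takes a `prefix` parameter (an assumed name) that defaults to "test", so existing callers keep their current behaviour while validation runs can opt into "val". `PrefixedMetricHolder`, `evaluate` and `_log` are illustrative names only, and the losses are dummy placeholders.

```python
from typing import Dict, List


class PrefixedMetricHolder:
    """Hypothetical stand-in for a metric-tracking wrapper whose evaluation
    method accepts the prefix to log under (not baal's actual API)."""

    def __init__(self) -> None:
        self.metrics: Dict[str, float] = {}

    def _log(self, prefix: str, name: str, value: float) -> None:
        # Store the value under "<prefix>_<name>", e.g. "val_loss".
        self.metrics[f"{prefix}_{name}"] = value

    def evaluate(self, batch_losses: List[float], prefix: str = "test") -> float:
        """Average the per-batch losses and log the result under *prefix*.

        Defaulting to "test" keeps existing callers unchanged.
        """
        if prefix not in ("val", "test"):
            raise ValueError(f"unsupported prefix: {prefix!r}")
        mean_loss = sum(batch_losses) / len(batch_losses)
        self._log(prefix, "loss", mean_loss)
        return mean_loss

    def get_metrics(self) -> Dict[str, float]:
        # Return a copy so callers cannot mutate the internal state.
        return dict(self.metrics)


holder = PrefixedMetricHolder()
holder.evaluate([0.25, 0.75], prefix="val")  # dummy validation-set losses
holder.evaluate([0.5, 1.0])                  # dummy test-set losses, default "test"
print(holder.get_metrics())
```

With a design like this, the first loop in the problem description could pass a hypothetical `prefix="val"` when evaluating `val_dataset` and keep the default for `test_dataset`, so both sets of metrics would survive in the dict returned by `get_metrics()`.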

Describe alternatives you've considered
A simple but cumbersome workaround is to build a dict and copy each "test_" metric that actually holds a validation result into it under a "val_" key, as follows:

```python
trainval_hist = wrapper.train_and_test_on_datasets(...)
trainval_last = trainval_hist[-1]  # NOTE: take log at last epoch
metrics[len(active_set)] = {
    "train_loss": trainval_last["train_loss"],
    "train_accuracy": trainval_last["train_accuracy"],
    "dataset_size": len(active_set),
    "epochs_trained": len(trainval_hist),
    # Validation results were logged under "test_"; store them as "val_".
    "val_loss": trainval_last["test_loss"],
    "val_accuracy": trainval_last["test_accuracy"],
}
```
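A slightly less cumbersome variant of this workaround renames every "test_" key in a history entry at once instead of copying each metric by hand. `rename_prefix` is a hypothetical helper; it assumes, as the snippet above does, that each entry returned by `train_and_test_on_datasets` is a dict keyed by names such as "train_loss" and "test_accuracy". The values below are dummy placeholders.

```python
from typing import Dict, Mapping


def rename_prefix(
    log: Mapping[str, float], old: str = "test_", new: str = "val_"
) -> Dict[str, float]:
    """Return a copy of *log* where every key starting with *old* is
    re-prefixed with *new*; other keys (e.g. "train_loss") are kept as-is."""
    renamed: Dict[str, float] = {}
    for key, value in log.items():
        if key.startswith(old):
            renamed[new + key[len(old):]] = value
        else:
            renamed[key] = value
    return renamed


# Shape of one history entry, using the keys from the snippet above
# (values are dummy placeholders).
trainval_last = {
    "train_loss": 1.0,
    "train_accuracy": 0.5,
    "test_loss": 2.0,
    "test_accuracy": 0.25,
}
print(rename_prefix(trainval_last))
```

Extra fields could then be merged in, e.g. `{**rename_prefix(trainval_last), "dataset_size": len(active_set)}`. Unlike listing keys by hand, this also carries over any other "test_" metrics the wrapper logs, which may or may not be desired.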

Additional context
None.

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
