From ad2f6c7e641b793c3156b924c95a0c4cff425325 Mon Sep 17 00:00:00 2001 From: "Terlouw, Barbara" Date: Tue, 7 Oct 2025 10:40:41 +0200 Subject: [PATCH 1/2] Code cleanup --- .../test/test_calculate_sample_weights.py | 2 +- mewtwo/machine_learning/transformer/finetune_bert.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mewtwo/machine_learning/data_preparation/test/test_calculate_sample_weights.py b/mewtwo/machine_learning/data_preparation/test/test_calculate_sample_weights.py index 30f5305..1bba71f 100644 --- a/mewtwo/machine_learning/data_preparation/test/test_calculate_sample_weights.py +++ b/mewtwo/machine_learning/data_preparation/test/test_calculate_sample_weights.py @@ -3,6 +3,7 @@ from mewtwo.machine_learning.data_preparation.calculate_sample_weights import get_sample_weights + class CalculateSampleWeights(unittest.TestCase): def test_get_sample_weights(self): dataset_1 = [0.1, 0.3, 0.5, 0.7, 0.9] @@ -22,6 +23,5 @@ def assertNearlyEqual(self, list_1, list_2): self.fail(f"Lists are not equal: {list_1}, {list_2}. \n First mismatching element: {i} ([{element_1}], [{element_2}])") - if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/mewtwo/machine_learning/transformer/finetune_bert.py b/mewtwo/machine_learning/transformer/finetune_bert.py index 74d9dbd..860bec5 100644 --- a/mewtwo/machine_learning/transformer/finetune_bert.py +++ b/mewtwo/machine_learning/transformer/finetune_bert.py @@ -1,7 +1,8 @@ import argparse import os +from typing import TextIO -from mewtwo.machine_learning.transformer.model import load_model +from mewtwo.machine_learning.transformer.model import load_model, Model from mewtwo.machine_learning.transformer.config.config_types import SchedulerType, EarlyStoppingMetricType from scipy.stats import pearsonr, spearmanr @@ -34,7 +35,7 @@ def parse_arguments() -> argparse.Namespace: return args -def metric_has_improved(old_metric, new_metric, metric_type): +def metric_has_improved(old_metric: float, new_metric: float, metric_type: EarlyStoppingMetricType) -> bool: if metric_type in EarlyStoppingMetricType.MAX_METRICS: if new_metric > old_metric: return True @@ -49,7 +50,7 @@ def metric_has_improved(old_metric, new_metric, metric_type): raise ValueError(f"Unrecognised early stopping metric: {metric_type.name}") -def get_metric(eval_loss, pearson, spearman, metric_type): +def get_metric(eval_loss: float, pearson: float, spearman: float, metric_type: EarlyStoppingMetricType) -> float: if metric_type == EarlyStoppingMetricType.PEARSON_R: return pearson elif metric_type == EarlyStoppingMetricType.SPEARMAN_R: @@ -60,7 +61,7 @@ def get_metric(eval_loss, pearson, spearman, metric_type): raise ValueError(f"Unrecognised early stopping metric: {metric_type.name}") -def finetune(model, summary, epochs, out_dir, header=False): +def finetune(model: Model, summary: TextIO, epochs: int, out_dir: str, header=False) -> None: if header: summary.write("epoch\taverage_train_loss\taverage_eval_loss\tpearsonr\tspearmanr\n") config_file = os.path.join(out_dir, "model.config") From 2cd77b5893b51fa251df8f4dd83466dec8ad76d2 Mon Sep 17 00:00:00 2001 From: "Terlouw, Barbara" Date: Tue, 7 Oct 2025 10:41:03 +0200 Subject: [PATCH 2/2] Add 'full model finetuning' config type --- mewtwo/machine_learning/transformer/config/config_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mewtwo/machine_learning/transformer/config/config_types.py b/mewtwo/machine_learning/transformer/config/config_types.py index fed9d49..e84be64 100644 --- a/mewtwo/machine_learning/transformer/config/config_types.py +++ b/mewtwo/machine_learning/transformer/config/config_types.py @@ -33,6 +33,7 @@ def from_string_description(string_description) -> "LossFunctionType": class FinetuningType(Enum): LINEAR_HEAD = 1 ADAPTER = 2 + FULL = 3 @staticmethod def from_string_description(string_description) -> "FinetuningType":