Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mewtwo.machine_learning.data_preparation.calculate_sample_weights import get_sample_weights


class CalculateSampleWeights(unittest.TestCase):
def test_get_sample_weights(self):
dataset_1 = [0.1, 0.3, 0.5, 0.7, 0.9]
Expand All @@ -22,6 +23,5 @@ def assertNearlyEqual(self, list_1, list_2):
self.fail(f"Lists are not equal: {list_1}, {list_2}. \n First mismatching element: {i} ([{element_1}], [{element_2}])")



if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions mewtwo/machine_learning/transformer/config/config_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def from_string_description(string_description) -> "LossFunctionType":
class FinetuningType(Enum):
LINEAR_HEAD = 1
ADAPTER = 2
FULL = 3

@staticmethod
def from_string_description(string_description) -> "FinetuningType":
Expand Down
9 changes: 5 additions & 4 deletions mewtwo/machine_learning/transformer/finetune_bert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import argparse
import os
from typing import TextIO

from mewtwo.machine_learning.transformer.model import load_model
from mewtwo.machine_learning.transformer.model import load_model, Model
from mewtwo.machine_learning.transformer.config.config_types import SchedulerType, EarlyStoppingMetricType

from scipy.stats import pearsonr, spearmanr
Expand Down Expand Up @@ -34,7 +35,7 @@ def parse_arguments() -> argparse.Namespace:
return args


def metric_has_improved(old_metric, new_metric, metric_type):
def metric_has_improved(old_metric: float, new_metric: float, metric_type: EarlyStoppingMetricType) -> bool:
if metric_type in EarlyStoppingMetricType.MAX_METRICS:
if new_metric > old_metric:
return True
Expand All @@ -49,7 +50,7 @@ def metric_has_improved(old_metric, new_metric, metric_type):
raise ValueError(f"Unrecognised early stopping metric: {metric_type.name}")


def get_metric(eval_loss, pearson, spearman, metric_type):
def get_metric(eval_loss: float, pearson: float, spearman: float, metric_type: EarlyStoppingMetricType) -> float:
if metric_type == EarlyStoppingMetricType.PEARSON_R:
return pearson
elif metric_type == EarlyStoppingMetricType.SPEARMAN_R:
Expand All @@ -60,7 +61,7 @@ def get_metric(eval_loss, pearson, spearman, metric_type):
raise ValueError(f"Unrecognised early stopping metric: {metric_type.name}")


def finetune(model, summary, epochs, out_dir, header=False):
def finetune(model: Model, summary: TextIO, epochs: int, out_dir: str, header=False) -> None:
if header:
summary.write("epoch\taverage_train_loss\taverage_eval_loss\tpearsonr\tspearmanr\n")
config_file = os.path.join(out_dir, "model.config")
Expand Down
Loading