-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGermanModel.py
More file actions
60 lines (48 loc) · 2.38 KB
/
GermanModel.py
File metadata and controls
60 lines (48 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Basics
from sklearn.metrics import balanced_accuracy_score
# Pytorch
import torch
from torch import nn, Tensor
# Huggingface & co.
from transformers import RobertaConfig, RobertaForMaskedLM, PreTrainedTokenizer
# Typing
from typing import Dict, List, Tuple, Union, Optional, Any
# Own Files
from MyTraining import MyTraining, BasicModelOutput
class GermanModel(nn.Module):
    """RoBERTa masked-language model for pretraining on German text.

    Wraps `RobertaForMaskedLM` and implements the small interface used by
    `MyTraining` (forward / compute_metrics / compare_metrics). The chosen
    hyperparameters are collected in `wandb_config` so they can be logged.
    """

    def __init__(self, vocab_size: int, max_seq_length: int, tokenizer: PreTrainedTokenizer):
        """Build a randomly initialised RoBERTa MLM.

        Args:
            vocab_size: size of the tokenizer vocabulary.
            max_seq_length: maximum (padded) sequence length in tokens.
            tokenizer: tokenizer instance; kept only for manual evaluation.
        """
        super().__init__()
        hidden_size = 512
        self.tokenizer = tokenizer  # kept for manual evaluation / decoding
        config = RobertaConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=8,
            num_attention_heads=8,
            intermediate_size=4 * hidden_size,
            # HF RoBERTa offsets position ids by padding_idx + 1 (i.e. by 2),
            # so max_position_embeddings must exceed the padded sequence
            # length by at least 2 (cf. roberta-base: 514 for 512 tokens).
            # +10 leaves comfortable headroom.
            max_position_embeddings=max_seq_length + 10,
        )
        self.roberta = RobertaForMaskedLM(config)
        self.config = self.roberta.config
        # these hyperparameters are logged to wandb
        self.wandb_config = {
            "hidden_size": hidden_size,
            "num_hidden_layers": config.num_hidden_layers,
            "num_attention_heads": config.num_attention_heads,
            "intermediate_size": config.intermediate_size,
            "max_position_embeddings": config.max_position_embeddings,
        }

    def forward(self, input_ids, attention_mask, labels) -> BasicModelOutput:
        """Run the MLM and return only its loss.

        Logits and labels are deliberately returned as empty tensors: MLM
        logits are (batch, seq, vocab)-sized and training only needs the
        loss, so keeping them would waste memory.
        """
        output = self.roberta(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        empty = torch.tensor([], device=output.loss.device)
        return BasicModelOutput(loss=output.loss, logits=empty, labels=empty)

    def compute_metrics(self, logits: Tensor, labels: Tensor) -> float:
        """Return the balanced accuracy of argmax(logits) against labels.

        NOTE(review): forward() emits empty logits/labels during MLM
        training, so this is only meaningful when called with real logits
        (e.g. manual evaluation) — confirm against the training loop.
        """
        predictions = torch.argmax(logits, dim=-1)
        return balanced_accuracy_score(labels.cpu(), predictions.cpu())

    def compare_metrics(self, new_metric: float, best_metric: float) -> bool:
        """Return True when new_metric improves on best_metric (higher is better)."""
        return new_metric > best_metric