-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelper.py
More file actions
executable file
·57 lines (48 loc) · 1.78 KB
/
helper.py
File metadata and controls
executable file
·57 lines (48 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
'''
Helper utilities: seed_everything, get_preds_ordered,
print_test_metrics, get_data, get_true_labels
'''
import os
import torch
import random
import numpy as np
from sklearn.metrics import f1_score,accuracy_score
from fastai import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks import *
from pathlib import Path
from sklearn.model_selection import train_test_split
def seed_everything(seed=42):
    """
    Seed every random number generator this module relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices), and
    configures cuDNN so that it neither uses non-deterministic algorithms nor
    auto-tunes its kernel selection.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    random.seed(seed)
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this only
    # influences child processes, not str/bytes hashing in the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA seeding calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different kernels from run to run, which
    # defeats the purpose of seeding; switch it off explicitly.
    torch.backends.cudnn.benchmark = False
def get_preds_ordered(ds_type:DatasetType,learner:Learner):
    """
    Return the learner's predictions for `ds_type`, restored to dataset order.

    `get_preds` does not necessarily yield items in their original order, so
    the order in which the dataloader's sampler visits the items is inverted
    and used to re-sort the rows (same approach as fastai's RNNLearner).

    Args:
        ds_type: Which dataset of `learner.data` to predict on.
        learner: Learner providing `get_preds` and `data.dl`.

    Returns:
        NumPy array of predictions, indexed by row and then by column.
    """
    # Predictions are fetched before the sampler is walked, as in the
    # original implementation.
    pred_tensor = learner.get_preds(ds_type)[0]
    preds = pred_tensor.detach().cpu().numpy()
    visit_order = list(learner.data.dl(ds_type).sampler)
    # argsort of the visiting order gives, for each original position,
    # the row of `preds` that holds its prediction.
    restore_idx = np.argsort(visit_order)
    return preds[restore_idx, :]
def print_test_metrics(true_labels,learner):
    """
    Print the accuracy and F1 score of `learner` on its test set.

    The predicted class for each item is the arg-max over the columns of the
    ordered test predictions.

    Args:
        true_labels: Ground-truth labels, in dataset order.
        learner: Learner whose test-set predictions are evaluated.

    NOTE(review): `f1_score` is called with its default averaging, so labels
    are presumably binary - confirm against the callers.
    """
    test_preds = get_preds_ordered(DatasetType.Test,learner)
    predicted_classes = np.argmax(test_preds, axis=1)
    acc = accuracy_score(true_labels, predicted_classes)
    f1s = f1_score(true_labels, predicted_classes)
    print(f'The accuracy is {round(acc*100,2)}%, the f1_score is {round(f1s*100,2)}%')
def get_data(path,train_file:str,test_file:str):
    """
    Load the train CSV and, optionally, the test CSV from a directory.

    Args:
        path: Directory containing the CSV files (str or Path).
        train_file: File name of the training CSV, relative to `path`.
        test_file: File name of the test CSV, relative to `path`, or None
            when there is no test set.

    Returns:
        Tuple `(df_train, df_test)`; `df_test` is None when `test_file`
        is None.
    """
    base_dir = Path(path)
    df_train = pd.read_csv(base_dir/train_file)
    # Identity check: None is the only value that means "no test file".
    if test_file is not None:
        df_test = pd.read_csv(base_dir/test_file)
    else:
        df_test = None
    return df_train,df_test
def get_true_labels(df,config_label:list)->np.ndarray:
    """
    Extract the ground-truth label column from `df` as a NumPy array.

    Args:
        df: Object indexable by column name whose columns expose `tolist()`
            (presumably a pandas DataFrame).
        config_label: Sequence of label column names; only the first entry
            is used.

    Returns:
        NumPy array holding the values of the selected column.

    Raises:
        IndexError: If `config_label` is empty.
    """
    label_column = config_label[0]
    return np.array(df[label_column].tolist())