diff --git a/main_classificaiton.py b/main_classificaiton.py index dace137..c8c8261 100644 --- a/main_classificaiton.py +++ b/main_classificaiton.py @@ -1,6 +1,7 @@ import torch import pandas as pd import numpy as np +import os from sklearn.metrics import accuracy_score from models.train_model import Train_Test @@ -141,6 +142,9 @@ def save_model(self, best_model, best_model_path): :param best_model_path: path for saving model :type best_model_path: str """ + + # make folder to save model + os.makedirs('./ckpt', exist_ok=True) # save model torch.save(best_model.state_dict(), best_model_path) @@ -223,4 +227,4 @@ def get_dataloader(self, x_data, y_data, batch_size, shuffle): # DataLoader 구축 data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) - return data_loader \ No newline at end of file + return data_loader