-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcatboost_algorithm.py
More file actions
139 lines (119 loc) · 4.76 KB
/
catboost_algorithm.py
File metadata and controls
139 lines (119 loc) · 4.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from catboost import CatBoostClassifier
import numpy as np
from .base import BaseAlgorithm
class CatBoostAlgorithm(BaseAlgorithm):
    """CatBoost implementation for classification.

    Wraps a ``CatBoostClassifier`` (stored as ``self.model``), keeps the
    constructor hyperparameters as instance attributes, and forwards most
    methods directly to the wrapped model.
    """

    def __init__(
        self,
        iterations=1000,
        learning_rate=0.1,
        depth=6,
        l2_leaf_reg=3,
        random_strength=1,
        bagging_temperature=1,
        border_count=254,
        random_state=None,
        early_stopping_rounds=None,
        task_type='CPU',
        thread_count=-1,
        verbose=100
    ):
        """Store hyperparameters and build the wrapped CatBoostClassifier.

        Every argument except ``early_stopping_rounds`` is forwarded to the
        ``CatBoostClassifier`` constructor under the same name, except
        ``random_state``, which is forwarded as ``random_seed``.

        Args:
            iterations: CatBoost ``iterations``.
            learning_rate: CatBoost ``learning_rate``.
            depth: CatBoost ``depth``.
            l2_leaf_reg: CatBoost ``l2_leaf_reg``.
            random_strength: CatBoost ``random_strength``.
            bagging_temperature: CatBoost ``bagging_temperature``.
            border_count: CatBoost ``border_count``.
            random_state: Passed to CatBoost as ``random_seed``; also seeds
                the train/validation split made in :meth:`fit`.
            early_stopping_rounds: When not None, :meth:`fit` holds out a
                validation split and passes this value to CatBoost's
                ``fit(early_stopping_rounds=...)``.
            task_type: CatBoost ``task_type``.
            thread_count: CatBoost ``thread_count``.
            verbose: CatBoost ``verbose``; also passed to every ``fit`` call.
        """
        self.iterations = iterations
        self.learning_rate = learning_rate
        self.depth = depth
        self.l2_leaf_reg = l2_leaf_reg
        self.random_strength = random_strength
        self.bagging_temperature = bagging_temperature
        self.border_count = border_count
        self.random_state = random_state
        self.early_stopping_rounds = early_stopping_rounds
        self.task_type = task_type
        self.thread_count = thread_count
        self.verbose = verbose
        self.model = CatBoostClassifier(
            iterations=iterations,
            learning_rate=learning_rate,
            depth=depth,
            l2_leaf_reg=l2_leaf_reg,
            random_strength=random_strength,
            bagging_temperature=bagging_temperature,
            border_count=border_count,
            random_seed=random_state,
            task_type=task_type,
            thread_count=thread_count,
            verbose=verbose
        )

    @staticmethod
    def _take_rows(data, idx):
        """Select rows of ``data`` by integer position ``idx``."""
        if hasattr(data, 'iloc'):
            # pandas objects: plain ``[]`` would be column/label-based, so use
            # positional ``.iloc`` to match the permutation indices.
            return data.iloc[idx]
        return np.asarray(data)[idx]

    def fit(self, X, y, validation_fraction=0.2):
        """Train the model, with optional early stopping.

        When ``early_stopping_rounds`` is None the model is fit on all of
        ``X``/``y``. Otherwise a random ``validation_fraction`` of the rows
        is held out and passed to CatBoost as ``eval_set``.

        Args:
            X: Training features (array-like or pandas DataFrame).
            y: Training labels (array-like or pandas Series), same length as X.
            validation_fraction: Fraction of rows held out for early stopping;
                must lie strictly between 0 and 1. Ignored when
                ``early_stopping_rounds`` is None.

        Returns:
            self, to allow call chaining.

        Raises:
            ValueError: If early stopping is enabled and
                ``validation_fraction`` is not in the open interval (0, 1).
        """
        if self.early_stopping_rounds is None:
            self.model.fit(X, y, verbose=self.verbose)
            return self

        if not 0 < validation_fraction < 1:
            raise ValueError(
                f"validation_fraction must be in (0, 1), got {validation_fraction!r}"
            )

        n_samples = len(X)
        n_val = int(validation_fraction * n_samples)
        if n_val < 1:
            # Too few rows to hold any out: an empty eval_set would be rejected
            # by CatBoost, so train on everything without early stopping.
            self.model.fit(X, y, verbose=self.verbose)
            return self

        # Seed the split from random_state so seeded runs are reproducible.
        # Without a seed, keep using numpy's global RNG (previous behavior).
        if self.random_state is None:
            rng = np.random
        else:
            rng = np.random.RandomState(self.random_state)
        indices = rng.permutation(n_samples)
        val_idx, train_idx = indices[:n_val], indices[n_val:]

        X_train, X_val = self._take_rows(X, train_idx), self._take_rows(X, val_idx)
        y_train, y_val = self._take_rows(y, train_idx), self._take_rows(y, val_idx)

        self.model.fit(
            X_train, y_train,
            eval_set=(X_val, y_val),
            early_stopping_rounds=self.early_stopping_rounds,
            verbose=self.verbose
        )
        return self

    def predict(self, X):
        """Return the wrapped model's predictions for ``X``."""
        return self.model.predict(X)

    def predict_proba(self, X):
        """Return the wrapped model's predicted probabilities for ``X``."""
        return self.model.predict_proba(X)

    def get_params(self, deep=True):
        """Return the constructor hyperparameters as a dict.

        Args:
            deep: Accepted for scikit-learn-style callers
                (``get_params(deep=False)``); it has no effect here.
        """
        return {
            'iterations': self.iterations,
            'learning_rate': self.learning_rate,
            'depth': self.depth,
            'l2_leaf_reg': self.l2_leaf_reg,
            'random_strength': self.random_strength,
            'bagging_temperature': self.bagging_temperature,
            'border_count': self.border_count,
            'random_state': self.random_state,
            'early_stopping_rounds': self.early_stopping_rounds,
            'task_type': self.task_type,
            'thread_count': self.thread_count,
            'verbose': self.verbose
        }

    # ``type`` shadows the builtin but is kept so existing keyword callers
    # (``get_feature_importance(type=...)``) keep working.
    def get_feature_importance(self, type='FeatureImportance'):
        """Return feature importances of the given CatBoost importance ``type``."""
        return self.model.get_feature_importance(type=type)

    def plot_feature_importance(self, feature_names=None, top_k=None):
        """Plot feature importances by delegating to the wrapped model."""
        # NOTE(review): CatBoostClassifier does not appear to expose a
        # ``plot_feature_importance`` method; confirm against the installed
        # catboost version (this may raise AttributeError).
        self.model.plot_feature_importance(feature_names, top_k)

    def plot_learning_curves(self):
        """Plot the model's learning curves by delegating to the wrapped model."""
        # NOTE(review): ``plot_metrics`` does not appear to be a
        # CatBoostClassifier method; verify it exists in the installed version.
        self.model.plot_metrics()

    def save_model(self, filepath):
        """Save the wrapped model to ``filepath`` via CatBoost's ``save_model``."""
        self.model.save_model(filepath)

    def load_model(self, filepath):
        """Load the wrapped model from ``filepath`` via CatBoost's ``load_model``."""
        self.model.load_model(filepath)

    def get_best_iteration(self):
        """Return the best iteration recorded by CatBoost (set when an eval set was used)."""
        return self.model.get_best_iteration()

    def get_best_score(self):
        """Return the best scores recorded by CatBoost during training."""
        return self.model.get_best_score()

    def get_evals_result(self):
        """Return the per-iteration evaluation results recorded during training."""
        return self.model.get_evals_result()

    def get_model_params(self):
        """Return the wrapped CatBoostClassifier's own parameters."""
        return self.model.get_params()

    def shrink(self, ntree_start=0, ntree_end=0):
        """Shrink the wrapped model to the trees between ``ntree_start`` and ``ntree_end``.

        Bounds follow CatBoost's ``shrink`` semantics.
        NOTE(review): the default ``ntree_end=0`` is passed through unchanged;
        confirm how the installed catboost version interprets it.

        Returns:
            self, to allow call chaining.
        """
        # CatBoost's signature is shrink(ntree_end, ntree_start=0); pass by
        # keyword so the two bounds are not swapped (the old positional call
        # sent ntree_start as ntree_end and vice versa).
        self.model.shrink(ntree_end=ntree_end, ntree_start=ntree_start)
        return self