-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
146 lines (124 loc) Β· 5.03 KB
/
train_model.py
File metadata and controls
146 lines (124 loc) Β· 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
BMS AI Model Training
Trains a Random Forest classifier to predict battery condition or application suitability.
Usage:
python train_model.py # Train condition model
python train_model.py --target application # Train application model
python train_model.py --source battery_data.csv # Train on specific CSV
"""
import argparse
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from data_pipeline import load_data, preprocess, FEATURE_COLUMNS
def train(data_source="synthetic", num_samples=5000, target="condition"):
    """Train a Random Forest classifier for the BMS and save model artifacts.

    Loads data through the project's ``data_pipeline`` helpers, fits a
    ``RandomForestClassifier`` on a stratified 80/20 split, prints an
    evaluation report to stdout, and persists the model, scaler, and label
    encoder with joblib.

    Args:
        data_source: ``"synthetic"`` to generate data, or a path to a CSV file.
        num_samples: Number of synthetic samples (used only when
            ``data_source`` is ``"synthetic"``).
        target: ``"condition"`` or ``"application"`` -- which label to predict.
            Also selects the artifact filenames written in step 8.

    Returns:
        Tuple of ``(model, scaler, label_encoder, accuracy)``.
    """
    target_label = "Condition" if target == "condition" else "Application Suitability"
    print("=" * 60)
    print(f" BMS AI Model Training - {target_label}")
    print("=" * 60)

    # --- 1. Load Data ---
    print(f"\nLoading data (source: {data_source})...")
    df = load_data(data_source, num_samples=num_samples)
    print(f"  Dataset: {df.shape[0]} samples, {df.shape[1]} features")
    print(f"  {target_label} distribution:")
    for cond, count in df[target].value_counts().items():
        pct = count / len(df) * 100
        print(f"    {cond:>18s}: {count:>5d} ({pct:.1f}%)")

    # --- 2. Preprocess ---
    print("\nPreprocessing...")
    X, y, label_encoder = preprocess(df, target=target)
    # Only features actually present in this dataset are used downstream.
    available_features = [col for col in FEATURE_COLUMNS if col in df.columns]
    print(f"  Features used: {available_features}")
    print(f"  Classes: {list(label_encoder.classes_)}")

    # --- 3. Train/Test Split ---
    # Stratify so every class keeps its proportion in both train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"\nSplit: {len(X_train)} train / {len(X_test)} test")

    # --- 4. Scale Features ---
    # Fit the scaler on training data only to avoid test-set leakage.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # --- 5. Train Model ---
    print("\nTraining Random Forest classifier...")
    model = RandomForestClassifier(
        n_estimators=200,
        max_depth=15,
        min_samples_split=5,
        min_samples_leaf=2,
        class_weight="balanced",  # compensate for imbalanced classes
        random_state=42,
        n_jobs=-1,  # use all available CPU cores
    )
    model.fit(X_train_scaled, y_train)

    # --- 6. Evaluate ---
    y_pred = model.predict(X_test_scaled)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"\n{'=' * 60}")
    print(" MODEL EVALUATION RESULTS")
    print(f"{'=' * 60}")
    print(f"\n  Accuracy: {accuracy:.4f} ({accuracy * 100:.1f}%)")
    print("\n  Classification Report:")
    target_names = list(label_encoder.classes_)
    report = classification_report(y_test, y_pred, target_names=target_names)
    for line in report.split("\n"):
        print(f"  {line}")

    # Render the confusion matrix as a labeled text table.
    print("\n  Confusion Matrix:")
    cm = confusion_matrix(y_test, y_pred)
    print(f"  {'':>10s} ", end="")
    for name in target_names:
        print(f"{name:>10s}", end="")
    print()
    for i, row in enumerate(cm):
        print(f"  {target_names[i]:>10s} ", end="")
        for val in row:
            print(f"{val:>10d}", end="")
        print()

    # --- 7. Feature Importance ---
    print("\n  Feature Importance:")
    importances = model.feature_importances_
    sorted_idx = np.argsort(importances)[::-1]  # most important first
    for idx in sorted_idx:
        bar = "#" * int(importances[idx] * 40)  # crude text bar chart
        print(f"  {available_features[idx]:>16s}: {importances[idx]:.4f} {bar}")

    # --- 8. Save Artifacts ---
    # Each target gets its own artifact filenames so both models can coexist.
    if target == "condition":
        model_file, scaler_file, encoder_file = "bms_model.pkl", "scaler.pkl", "label_encoder.pkl"
    else:
        model_file, scaler_file, encoder_file = "app_model.pkl", "app_scaler.pkl", "app_label_encoder.pkl"

    print("\nSaving model artifacts...")
    joblib.dump(model, model_file)
    joblib.dump(scaler, scaler_file)
    joblib.dump(label_encoder, encoder_file)
    print(f"  Saved {model_file}")
    print(f"  Saved {scaler_file}")
    print(f"  Saved {encoder_file}")

    print(f"\n{'=' * 60}")
    print(f" Training complete! Accuracy: {accuracy * 100:.1f}%")
    print(f"{'=' * 60}")
    return model, scaler, label_encoder, accuracy
if __name__ == "__main__":
    # CLI entry point: parse options and forward them straight into train().
    cli = argparse.ArgumentParser(description="Train BMS AI Model")
    cli.add_argument(
        "--source",
        default="synthetic",
        help="Data source: 'synthetic' or path to CSV file",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=5000,
        help="Number of synthetic samples (only used with synthetic source)",
    )
    cli.add_argument(
        "--target",
        default="condition",
        choices=["condition", "application"],
        help="Target to train: 'condition' or 'application'",
    )
    opts = cli.parse_args()
    train(data_source=opts.source, num_samples=opts.samples, target=opts.target)