-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_ml_validator.py
More file actions
97 lines (75 loc) · 3.05 KB
/
train_ml_validator.py
File metadata and controls
97 lines (75 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3
"""Script to train the ML Strategy Validator"""
import logging
from core.quant.validators.ml_validator import MLStrategyValidator
def train_ml_validator(data_days: int = 60, symbol: str = "XAUUSD") -> bool:
    """Train the ML strategy validator and print a summary of the results.

    Args:
        data_days: Number of days of historical data to use for training.
        symbol: Trading symbol to train on.

    Returns:
        True if training completes successfully, False otherwise. Any
        exception raised while training or printing the report is printed,
        logged with its traceback, and converted into False so the
        command-line entry point can map it to an exit status.
    """
    print("🤖 Training ML Strategy Validator")
    print("=" * 50)
    print(f"📊 Using {data_days} days of {symbol} data")
    print()
    try:
        return _run_training(symbol, data_days)
    except Exception as e:
        # Top-level boundary of the script: report the failure instead of
        # propagating it, so callers only ever see True/False.
        print(f"❌ Error during training: {e}")
        # Use a named logger: the module-level logging.exception() helper
        # implicitly calls basicConfig() on the root logger when it has no
        # handlers, which mutates global logging config if this is imported.
        logging.getLogger(__name__).exception("Training error")
        return False


def _run_training(symbol: str, data_days: int) -> bool:
    """Train a fresh MLStrategyValidator on historical data and print metrics.

    Args:
        symbol: Trading symbol passed to ``train_from_history``.
        data_days: Number of days passed as ``n_days`` to ``train_from_history``.

    Returns:
        False (after printing the message) when the results mapping contains
        an ``"error"`` key, True once all metrics have been printed.
    """
    validator = MLStrategyValidator()
    print("Training model with historical data...")
    results = validator.train_from_history(symbol=symbol, n_days=data_days)
    # The validator signals failure through an "error" entry rather than by
    # raising, so check for it before reading any metric keys.
    if "error" in results:
        print(f"❌ Training failed: {results['error']}")
        return False

    print("✅ Training completed successfully!")
    _print_summary(results)
    _print_classification_report(results)
    _print_feature_importances(results)
    print(f"\n💾 Model saved to: {validator.model_path}")
    return True


def _print_summary(results: dict) -> None:
    """Print sample counts and train/test accuracy from the training results."""
    print("📊 Training Results:")
    print(f" - Training samples: {results['training_samples']}")
    print(f" - Test samples: {results['test_samples']}")
    print(f" - Training accuracy: {results['train_accuracy']:.3f}")
    print(f" - Test accuracy: {results['test_accuracy']:.3f}")


def _print_classification_report(results: dict) -> None:
    """Print precision/recall/F1/support for each label in the report.

    Only dict-valued entries are printed; other entries are skipped
    (presumably scalar summaries such as an overall accuracy value —
    confirm against MLStrategyValidator's report format).
    """
    print("\n📈 Classification Report:")
    report = results['classification_report']
    for class_label, metrics in report.items():
        if not isinstance(metrics, dict):
            continue
        precision = metrics.get('precision', 0)
        recall = metrics.get('recall', 0)
        f1 = metrics.get('f1-score', 0)
        support = metrics.get('support', 0)
        print(f" {class_label}: Precision={precision:.3f}, "
              f"Recall={recall:.3f}, F1={f1:.3f}, Support={support}")


def _print_feature_importances(results: dict) -> None:
    """Print feature importances, most important first."""
    print("\n🎯 Feature Importances:")
    importances = results['feature_importances']
    ranked = sorted(importances.items(), key=lambda item: item[1], reverse=True)
    for feature, importance in ranked:
        print(f" {feature}: {importance:.3f}")
def main(argv=None) -> int:
    """Parse command-line arguments and train the ML strategy validator.

    Args:
        argv: Argument list to parse; when None, argparse reads ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 if training succeeded, 1 otherwise.

    Raises:
        SystemExit: With status 2 (argparse usage error) for invalid arguments,
            including a non-positive ``--days``.
    """
    import argparse

    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description="Train ML Strategy Validator")
    parser.add_argument(
        "--days",
        type=int,
        default=60,
        help="Number of days of historical data (default: 60)",
    )
    parser.add_argument(
        "--symbol",
        type=str,
        default="XAUUSD",
        help="Trading symbol (default: XAUUSD)",
    )
    args = parser.parse_args(argv)
    # A zero or negative window cannot contain training data; fail fast with
    # argparse's standard usage error instead of invoking the trainer.
    if args.days <= 0:
        parser.error("--days must be a positive integer")
    success = train_ml_validator(data_days=args.days, symbol=args.symbol)
    return 0 if success else 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive interpreter and is missing under `python -S`.
    import sys

    sys.exit(main())