-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
357 lines (304 loc) · 16.5 KB
/
main.py
File metadata and controls
357 lines (304 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
import os
import sys
import argparse
from datetime import datetime
# Directory containing this script. It is prepended to sys.path (the Python
# module search path, not the PATH environment variable) so that the
# `examples` and `src` packages imported below can be resolved. main() also
# uses it as the base directory for data/, models/ and results/.
src_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, src_path)
from examples.model_training import train_model
from examples.backtesting_example import backtest_model, simulate_real_time_trading, run_comprehensive_backtest
# Import MA strategy backtesting functions
from src.ma_strategy import *
def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before; passing an explicit
            list lets tests or other scripts drive the parser.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Notes:
        * ``--separate-long-short`` and ``--enhanced-charts`` are
          ``store_true`` flags whose default is already ``True``, so passing
          them has no effect. Use ``--no-separate-analysis`` /
          ``--simple-charts`` to turn those features off.
        * ``--save-excel`` is only recorded on the namespace; ``main()``
          derives the effective Excel setting from ``--no-excel``.
        * NOTE(review): ``--min-hold-periods`` and ``--optimize`` are parsed
          but are not currently read by ``main()``.
    """
    parser = argparse.ArgumentParser(description='Trading Event Detection and Model Training Project')

    # Main execution mode ('kd_long' was renamed to 'ma_strategy')
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'backtest', 'comprehensive', 'simulate', 'all', 'ma_strategy'],
                        help='Execution mode: train (train model), backtest (backtest model), '
                             'comprehensive (training+validation backtest), simulate (simulate trading), '
                             'ma_strategy (run MA strategy), all (all modes except MA)')

    # Data paths
    parser.add_argument('--data', type=str, help='Data file path')
    parser.add_argument('--validation-data', type=str, help='Validation data file path')

    # Model settings
    parser.add_argument('--model', type=str, default='lightgbm',
                        choices=['randomforest', 'gradientboosting', 'xgboost', 'lightgbm'],
                        help='Model type')
    parser.add_argument('--model-path', type=str, help='Trained model path')
    parser.add_argument('--feature-path', type=str, help='Feature list path')
    parser.add_argument('--scaler-path', type=str, help='Scaler path')

    # Backtesting parameters
    parser.add_argument('--long-threshold', type=float, default=0.0026, help='Long position probability threshold')
    parser.add_argument('--short-threshold', type=float, default=0.0026, help='Short position probability threshold')
    parser.add_argument('--save-excel', action='store_true', help='Save detailed Excel results (default: True)')
    parser.add_argument('--no-excel', action='store_true', help='Disable saving Excel results')
    parser.add_argument('--run-id', type=str, help='Unique identifier for this run (default: timestamp)')

    # Enhanced visualization options (the positive flags default to True;
    # the negative flags are what actually switch a feature off)
    parser.add_argument('--separate-long-short', action='store_true', default=True,
                        help='Perform separate analysis for long and short trades')
    parser.add_argument('--no-separate-analysis', action='store_true',
                        help='Disable separate long/short analysis')
    parser.add_argument('--enhanced-charts', action='store_true', default=True,
                        help='Use enhanced chart visualization')
    parser.add_argument('--simple-charts', action='store_true',
                        help='Use simple chart visualization')
    parser.add_argument('--no-show-plots', action='store_true',
                        help='Do not display plots (useful for batch processing)')

    # Training parameters
    parser.add_argument('--trials', type=int, default=100, help='Number of hyperparameter optimization trials')

    # MA Strategy parameters
    parser.add_argument('--ma-period', type=int, default=5, help='Period for Moving Average calculation')
    parser.add_argument('--commission', type=float, default=0.0, help='Trading commission cost per trade')
    parser.add_argument('--price-col', type=str, default='Close', help='Price column name for MA strategy')
    parser.add_argument('--min-hold-periods', type=int, default=5, help='Minimum holding periods for MA strategy')
    parser.add_argument('--optimize', action='store_true', help='Run parameter optimization for MA strategy')

    return parser.parse_args(argv)
def _fill_default_model_paths(args, models_dir):
    """Fill in default model, feature-list and scaler paths on ``args``.

    Only attributes that are still ``None`` are set, so paths given on the
    command line (or assigned after training in ``all`` mode) are kept.
    Defaults: ``<models_dir>/<args.model>.joblib``,
    ``<models_dir>/features.xlsx`` and ``<models_dir>/scaler.joblib``.
    """
    if args.model_path is None:
        args.model_path = os.path.join(models_dir, f"{args.model}.joblib")
    if args.feature_path is None:
        args.feature_path = os.path.join(models_dir, "features.xlsx")
    if args.scaler_path is None:
        args.scaler_path = os.path.join(models_dir, "scaler.joblib")


def _model_files_exist(args, logger):
    """Return True if the model, feature-list and scaler files all validate.

    Each path is checked with ``validate_file_path(..., must_exist=True)``.
    If any check raises ``ValidationError``, the error is logged and printed
    and False is returned.
    """
    from src.validation import ValidationError, validate_file_path
    try:
        validate_file_path(args.model_path, must_exist=True)
        validate_file_path(args.feature_path, must_exist=True)
        validate_file_path(args.scaler_path, must_exist=True)
    except ValidationError as e:
        logger.error(f"Model file validation error: {str(e)}")
        print(f"Error: {str(e)}")
        return False
    return True


def _run_ma_strategy(args, logger, save_excel, run_id, enhanced_charts, show_plots):
    """Run the moving-average strategy backtest on the validation data.

    Output is written under ``<src_path>/results/ma_strategy``. If the
    backtest raises an unexpected error, a simpler cumulative-return
    comparison chart over several MA periods is produced instead.

    Returns:
        int: 0 when the backtest (or, after a failure, the fallback chart)
        succeeded; 1 on a validation error, a missing file, or when the
        fallback chart also fails.
    """
    from src.validation import ValidationError

    logger.info("Starting MA Strategy Backtest")
    print("\n===== Starting MA Strategy Backtest =====")

    custom_path = os.path.join(src_path, 'results', 'ma_strategy')
    os.makedirs(custom_path, exist_ok=True)  # no-op when it already exists

    try:
        from src.validation import validate_numeric_parameter
        validate_numeric_parameter(args.ma_period, 'ma_period', min_value=1, max_value=500)
        validate_numeric_parameter(args.commission, 'commission', min_value=0.0)

        # backtest_ma_strategy returns a 2-tuple; only the first element
        # (a mapping of metric name -> value) is reported here.
        metrics_result, _ = backtest_ma_strategy(
            data_path=args.validation_data,
            ma_period=args.ma_period,
            price_col=args.price_col,
            save_excel=save_excel,
            run_id=run_id,
            base_dir=src_path,
            custom_path=custom_path,
            enhanced_charts=enhanced_charts,
            show_plots=show_plots,
            commission=args.commission
        )

        print("\nMA Strategy Results:")
        print("-" * 50)
        for key, value in metrics_result.items():
            if isinstance(value, float):
                print(f"{key}: {value:.4f}")
            else:
                print(f"{key}: {value}")
    except ValidationError as e:
        logger.error(f"Validation error in MA strategy: {str(e)}")
        print(f"\nValidation error: {str(e)}")
        return 1
    except FileNotFoundError as e:
        logger.error(f"File not found: {str(e)}")
        print(f"\nFile not found: {str(e)}")
        return 1
    except Exception as e:
        logger.error(f"Error occurred while executing MA strategy: {str(e)}", exc_info=True)
        print(f"\nError occurred while executing MA strategy: {str(e)}")
        print("Using simplified cumulative return comparison...")
        try:
            ma_periods = [3, 5, 10, 15, 20, 30, 50]
            chart_path, _ = create_cumulative_comparison_chart(
                args.validation_data,
                ma_periods=ma_periods,
                price_col=args.price_col,
                custom_path=custom_path
            )
            print(f"\nCumulative profit/loss comparison chart saved to: {chart_path}")
        except Exception as e2:
            logger.error(f"Failed to create fallback chart: {str(e2)}")
            print(f"Failed to create fallback chart: {str(e2)}")
            return 1
        # The primary backtest failed, so do not report it as successful.
        # The exit status stays 0 because the fallback chart was produced.
        logger.warning("MA Strategy backtest failed; fallback comparison chart was generated instead")
        print("\nMA Strategy backtest did not complete; fallback comparison chart generated instead.")
        print("\nProgram execution completed!")
        return 0

    logger.info("MA Strategy backtest completed successfully")
    print("\nMA Strategy backtest completed successfully!")
    print("\nProgram execution completed!")
    return 0


def main():
    """Main program entry point.

    Parses the command line, creates the data/model directories under
    ``src_path``, validates inputs and dispatches to the selected mode(s):
    ``ma_strategy``, or any of ``train`` / ``backtest`` / ``comprehensive`` /
    ``simulate`` (``all`` runs those four in order).

    Returns:
        int: process exit code, 0 on success and 1 on any handled error.
    """
    from src.logger import setup_logger
    from src.validation import ValidationError, validate_file_path, validate_model_type
    logger = setup_logger("main")

    try:
        args = parse_arguments()

        # Data and model directories live under the script's directory.
        data_dir = os.path.join(src_path, 'data', 'raw')
        models_dir = os.path.join(src_path, 'models')
        for directory in (data_dir, models_dir):
            os.makedirs(directory, exist_ok=True)

        # argparse defines every declared option on the namespace, so the
        # negative switches alone decide these settings. The positive flags
        # (--save-excel, --separate-long-short, --enhanced-charts) cannot
        # change anything and are kept only for CLI compatibility.
        save_excel = not args.no_excel
        separate_long_short = not args.no_separate_analysis
        enhanced_charts = not args.simple_charts
        show_plots = not args.no_show_plots

        # Default run identifier is the current local timestamp.
        run_id = args.run_id or datetime.now().strftime("%Y%m%d_%H%M%S")

        # User-supplied data paths are validated; the defaults are not.
        if args.data is None:
            args.data = os.path.join(data_dir, "TX00_training.xlsx")
        else:
            validate_file_path(args.data, must_exist=True, allowed_extensions=['.xlsx', '.csv'])
        if args.validation_data is None:
            args.validation_data = os.path.join(data_dir, "TX00_validation.xlsx")
        else:
            validate_file_path(args.validation_data, must_exist=True, allowed_extensions=['.xlsx', '.csv'])

        # Model type only matters for the ML modes.
        if args.mode in ('train', 'backtest', 'comprehensive', 'simulate', 'all'):
            validate_model_type(args.model, allowed_types=['randomforest', 'gradientboosting', 'xgboost', 'lightgbm'])
    except ValidationError as e:
        logger.error(f"Validation error: {str(e)}")
        print(f"Error: {str(e)}")
        return 1
    except Exception as e:
        logger.error(f"Unexpected error during initialization: {str(e)}", exc_info=True)
        print(f"Unexpected error: {str(e)}")
        return 1

    if args.mode == 'ma_strategy':
        return _run_ma_strategy(args, logger, save_excel, run_id, enhanced_charts, show_plots)

    # Execute selected mode(s) for the ML model
    try:
        if args.mode in ('train', 'all'):
            logger.info("Starting Model Training")
            print("\n===== Starting Model Training =====")
            try:
                _model, _feature_engineering, _train_metrics = train_model(args.data, args.model, args.trials)
                # In 'all' mode the following steps must use the freshly
                # trained artifacts, overriding any user-supplied paths.
                if args.mode == 'all':
                    args.model_path = os.path.join(models_dir, f"{args.model}.joblib")
                    args.feature_path = os.path.join(models_dir, "features.xlsx")
                    args.scaler_path = os.path.join(models_dir, "scaler.joblib")
            except FileNotFoundError as e:
                logger.error(f"Training data file not found: {str(e)}")
                print(f"Error: Training data file not found: {str(e)}")
                return 1
            except Exception as e:
                logger.error(f"Error during model training: {str(e)}", exc_info=True)
                print(f"Error during model training: {str(e)}")
                return 1

        if args.mode in ('backtest', 'all'):
            logger.info("Starting Model Backtesting")
            print("\n===== Starting Model Backtesting =====")
            _fill_default_model_paths(args, models_dir)
            if not _model_files_exist(args, logger):
                return 1
            try:
                backtest_model(
                    args.model_path,
                    args.feature_path,
                    args.scaler_path,
                    args.validation_data,
                    run_id=run_id,
                    base_dir=src_path,
                    save_excel=save_excel,
                    separate_long_short=separate_long_short,
                    enhanced_charts=enhanced_charts,
                    show_plots=show_plots
                )
                logger.info("Backtest completed successfully")
                print("\nBacktest completed successfully!")
            except Exception as e:
                logger.error(f"Error during backtesting: {str(e)}", exc_info=True)
                print(f"Error during backtesting: {str(e)}")
                return 1

        if args.mode in ('comprehensive', 'all'):
            logger.info("Starting Comprehensive Backtesting")
            print("\n===== Starting Comprehensive Backtesting (Training + Validation) =====")
            _fill_default_model_paths(args, models_dir)
            # Same up-front file check as the plain backtest mode.
            if not _model_files_exist(args, logger):
                return 1
            try:
                _training_metrics, _validation_metrics = run_comprehensive_backtest(
                    args.model_path,
                    args.feature_path,
                    args.scaler_path,
                    args.data,  # Training data
                    args.validation_data,
                    base_dir=src_path,
                    save_excel=save_excel,
                    separate_long_short=separate_long_short,
                    enhanced_charts=enhanced_charts,
                    show_plots=show_plots
                )
                logger.info("Comprehensive backtest completed successfully")
                print("\nComprehensive backtest completed successfully!")
            except Exception as e:
                logger.error(f"Error during comprehensive backtesting: {str(e)}", exc_info=True)
                print(f"Error during comprehensive backtesting: {str(e)}")
                return 1

        if args.mode in ('simulate', 'all'):
            logger.info("Starting Real-time Trading Simulation")
            print("\n===== Starting Real-time Trading Simulation =====")
            _fill_default_model_paths(args, models_dir)
            # Same up-front file check as the plain backtest mode.
            if not _model_files_exist(args, logger):
                return 1
            try:
                from src.validation import validate_probability_threshold
                validate_probability_threshold(args.long_threshold, 'long_threshold')
                validate_probability_threshold(args.short_threshold, 'short_threshold')
            except ValidationError as e:
                logger.error(f"Threshold validation error: {str(e)}")
                print(f"Error: {str(e)}")
                return 1
            try:
                trade_df = simulate_real_time_trading(
                    args.model_path,
                    args.feature_path,
                    args.scaler_path,
                    args.validation_data,
                    args.long_threshold,
                    args.short_threshold,
                    run_id=run_id,
                    base_dir=src_path,
                    save_excel=save_excel,
                    enhanced_charts=enhanced_charts,
                    show_plots=show_plots
                )
                if trade_df is not None and len(trade_df) > 0:
                    logger.info("Simulation completed successfully")
                    print("\nSimulation completed successfully!")
                else:
                    logger.warning("Simulation completed but no trades were generated")
                    print("\nSimulation completed, but no trades were generated with the specified thresholds.")
            except Exception as e:
                logger.error(f"Error during simulation: {str(e)}", exc_info=True)
                print(f"Error during simulation: {str(e)}")
                return 1

        logger.info("Program execution completed successfully")
        print("\nProgram execution completed!")
        return 0
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        print(f"\nUnexpected error: {str(e)}")
        return 1
if __name__ == "__main__":
    # main() hands back an integer exit status; a None result counts as success.
    status = main()
    sys.exit(0 if status is None else status)