1+ # AI/modules/trader/backtest/run_backtrader_single.py
2+ """
3+ [Backtrader 기반 단일 종목 정밀 백테스트]
4+ - Walk-Forward Validation 지원
5+ - strategies/rule_based.py 의 RuleBasedStrategy 클래스 사용
6+ - AI Score 시각화 기능 포함
7+ """
8+
9+ import sys
10+ import os
11+ import backtrader as bt
12+ import pandas as pd
13+ import numpy as np
14+
15+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
16+ project_root = os .path .abspath (os .path .join (current_dir , "../../../.." ))
17+ if project_root not in sys .path :
18+ sys .path .append (project_root )
19+
20+ from AI .modules .signal .core .data_loader import SignalDataLoader
21+ from AI .modules .signal .models import get_model
22+ # ★ [수정] 클래스 기반 전략 불러오기
23+ from AI .modules .trader .strategies .rule_based import RuleBasedStrategy
24+
25+ class AIScoreObserver (bt .Observer ):
26+ """차트 하단에 AI 모델 점수를 그리기 위한 클래스"""
27+ lines = ('score' , 'limit_buy' , 'limit_sell' )
28+ plotinfo = dict (plot = True , subplot = True , plotname = 'AI Probability' )
29+ plotlines = dict (
30+ score = dict (marker = 'o' , markersize = 3.0 , color = 'blue' , _fill_gt = (0.5 , 'red' ), _fill_lt = (0.5 , 'green' )),
31+ limit_buy = dict (color = 'red' , linestyle = '--' ),
32+ limit_sell = dict (color = 'green' , linestyle = '--' )
33+ )
34+
35+ def next (self ):
36+ score = getattr (self ._owner , 'current_score' , 0.5 )
37+ self .lines .score [0 ] = score
38+ self .lines .limit_buy [0 ] = 0.65
39+ self .lines .limit_sell [0 ] = 0.40
40+
41+ class TransformerWalkForwardStrategy (bt .Strategy ):
42+ params = (
43+ ('model_weights_path' , None ),
44+ ('raw_df' , None ),
45+ ('features' , None ),
46+ ('loader' , None ),
47+ ('seq_len' , 60 ),
48+ ('model_name' , "transformer" ),
49+ )
50+
51+ def __init__ (self ):
52+ self .model = self ._load_model ()
53+ self .order = None
54+ self .current_score = 0.5
55+ # ★ [수정] 전략 객체 초기화
56+ self .strategy_logic = RuleBasedStrategy (buy_threshold = 0.65 , sell_threshold = 0.40 )
57+
58+ def log (self , txt , dt = None ):
59+ dt = dt or self .datas [0 ].datetime .date (0 )
60+ print (f'[{ dt .isoformat ()} ] { txt } ' )
61+
62+ def _load_model (self ):
63+ path = self .p .model_weights_path
64+ if not path or not os .path .exists (path ):
65+ self .log ("⚠️ 모델 가중치 파일 없음." )
66+ return None
67+
68+ default_config = {
69+ "head_size" : 256 , "num_heads" : 4 , "ff_dim" : 4 ,
70+ "num_blocks" : 4 , "mlp_units" : [128 ], "dropout" : 0.1
71+ }
72+ try :
73+ model = get_model (self .p .model_name , default_config )
74+ model .build ((None , self .p .seq_len , len (self .p .features )))
75+ if hasattr (model , 'model' ):
76+ model .model .load_weights (path )
77+ else :
78+ model .load_weights (path )
79+ return model
80+ except Exception as e :
81+ self .log (f"⚠️ 모델 로드 에러: { e } " )
82+ return None
83+
84+ def notify_order (self , order ):
85+ if order .status in [order .Completed ]:
86+ if order .isbuy ():
87+ self .log (f"🔵 BUY 체결 @ { order .executed .price :,.0f} " )
88+ elif order .issell ():
89+ self .log (f"🔴 SELL 체결 @ { order .executed .price :,.0f} " )
90+ self .order = None
91+
92+ def next (self ):
93+ if len (self ) < self .p .seq_len :
94+ return
95+
96+ current_date = self .datas [0 ].datetime .datetime (0 )
97+ past_data = self .p .raw_df .loc [:current_date ]
98+ if len (past_data ) < self .p .seq_len :
99+ return
100+
101+ # 1. Walk-Forward Prediction
102+ self .p .loader .scaler .fit (past_data [self .p .features ])
103+ recent_data = past_data .iloc [- self .p .seq_len :]
104+ input_seq = np .expand_dims (self .p .loader .scaler .transform (recent_data [self .p .features ]), axis = 0 )
105+
106+ if self .model :
107+ pred = self .model .predict (input_seq , verbose = 0 )
108+ score = float (pred [0 ][0 ])
109+ else :
110+ score = 0.5
111+
112+ self .current_score = score
113+
114+ # 2. 매매 판단 (RuleBasedStrategy 사용)
115+ if self .order : return # 이미 주문 중이면 패스
116+
117+ position_qty = self .position .size
118+ # ★ [수정] 클래스 메서드 호출로 변경 (코드가 훨씬 깔끔해짐)
119+ decision = self .strategy_logic .get_action (score , position_qty )
120+
121+ if decision ['type' ] == 'BUY' :
122+ # 보유 현금의 95%만큼 매수 계산 (Backtrader 로직)
123+ cash = self .broker .get_cash ()
124+ price = self .datas [0 ].close [0 ]
125+ # 수수료 고려하여 안전하게 계산
126+ size = int ((cash * 0.95 ) / price )
127+ if size > 0 :
128+ self .log (f"BUY 신호 (Score: { score :.2f} )" )
129+ self .order = self .buy (size = size )
130+
131+ elif decision ['type' ] == 'SELL' :
132+ if position_qty > 0 :
133+ self .log (f"SELL 신호 (Score: { score :.2f} )" )
134+ self .order = self .close () # 전량 청산
135+
136+ def run_single_backtest (ticker = "AAPL" , start_date = "2024-01-01" , end_date = "2024-06-01" , enable_plot = True ):
137+ print (f"\n === [{ ticker } ] 단일 종목 백테스트 시작 ===" )
138+
139+ weight_path = os .path .join (project_root , "AI/data/weights/transformer/universal_transformer.keras" )
140+ loader = SignalDataLoader (sequence_length = 60 )
141+ df = loader .load_data (ticker , start_date , end_date )
142+
143+ if df is None or len (df ) < 100 :
144+ print ("❌ 데이터 로드 실패" )
145+ return
146+
147+ if 'date' in df .columns :
148+ df ['date' ] = pd .to_datetime (df ['date' ])
149+ df .set_index ('date' , inplace = True )
150+
151+ data_feed = bt .feeds .PandasData (dataname = df )
152+ features = df .select_dtypes (include = [np .number ]).columns .tolist ()
153+
154+ cerebro = bt .Cerebro ()
155+ cerebro .adddata (data_feed )
156+
157+ cerebro .addstrategy (
158+ TransformerWalkForwardStrategy ,
159+ model_weights_path = weight_path ,
160+ raw_df = df ,
161+ features = features ,
162+ loader = loader
163+ )
164+
165+ if enable_plot :
166+ cerebro .addobserver (AIScoreObserver )
167+
168+ cerebro .broker .setcash (10_000_000 )
169+ cerebro .broker .setcommission (commission = 0.0015 )
170+ cerebro .addanalyzer (bt .analyzers .SharpeRatio , _name = 'sharpe' , riskfreerate = 0.0 )
171+ cerebro .addanalyzer (bt .analyzers .DrawDown , _name = 'drawdown' )
172+
173+ print (f"💰 초기 자산: { cerebro .broker .getvalue ():,.0f} 원" )
174+ results = cerebro .run ()
175+
176+ strat = results [0 ]
177+ final_val = cerebro .broker .getvalue ()
178+ mdd = strat .analyzers .drawdown .get_analysis ()['max' ]['drawdown' ]
179+ sharpe = strat .analyzers .sharpe .get_analysis ().get ('sharperatio' , 0.0 )
180+
181+ print (f"💰 최종 자산: { final_val :,.0f} 원 ({ (final_val / 10000000 - 1 )* 100 :.2f} %)" )
182+ print (f"📉 MDD: { mdd :.2f} % | 📊 Sharpe: { sharpe :.4f} " )
183+
184+ if enable_plot :
185+ cerebro .plot (style = 'candlestick' , volume = False )
186+
187+ if __name__ == "__main__" :
188+ run_single_backtest ()
0 commit comments