Skip to content

Commit c701355

Browse files
authored
[AI] SISC-185 [FEAT] 강화학습 기본코드 작성 (#186)
* [AI] SISC-185 [FEAT] RL 기본코드 작성 및 조정종가 코드 수정 * [AI] SISC-185 [FEAT] RL 학습코드작성 * Update Simulator initialization with initial balance Add initial balance parameter to Simulator instantiation * [AI] SISC-185 [FIX] 초기자금 초기화 버그 수정
1 parent 2947f4d commit c701355

15 files changed

Lines changed: 762 additions & 660 deletions

File tree

AI/libs/database/fetcher.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# libs/database/fetcher.py
1+
# AI/libs/database/fetcher.py
22
from __future__ import annotations
33
from typing import Optional
44
import pandas as pd
@@ -15,25 +15,22 @@ def fetch_ohlcv(
1515
db_name: str = "db",
1616
) -> pd.DataFrame:
1717
"""
18-
특정 티커, 날짜 범위의 OHLCV 데이터를 DB에서 불러오기 (SQLAlchemy 엔진 사용)
18+
특정 티커, 날짜 범위의 OHLCV 데이터를 DB에서 불러오기
1919
2020
Args:
2121
ticker (str): 종목 코드 (예: "AAPL")
22-
start (str): 시작일자 'YYYY-MM-DD' (inclusive)
23-
end (str): 종료일자 'YYYY-MM-DD' (inclusive)
24-
interval (str): 데이터 간격 ('1d' 등) - 현재 테이블이 일봉만 제공하면 무시됨
25-
db_name (str): get_engine()가 참조할 설정 블록 이름 (예: "db", "report_DB")
22+
start (str): 시작일자 'YYYY-MM-DD'
23+
end (str): 종료일자 'YYYY-MM-DD'
24+
interval (str): 데이터 간격 (현재 일봉만 지원)
25+
db_name (str): DB 설정 이름
2626
2727
Returns:
28-
pd.DataFrame: 컬럼 = [ticker, date, open, high, low, close, adjusted_close, volume]
29-
(date 컬럼은 pandas datetime으로 변환됨)
28+
pd.DataFrame: [ticker, date, open, high, low, close, adjusted_close, volume]
3029
"""
3130

32-
# 1) SQLAlchemy engine 얻기 (configs/config.json 기준)
3331
engine = get_engine(db_name)
3432

35-
# 2) 쿼리: named parameter(:ticker 등) 사용 -> 안전하고 가독성 좋음
36-
# - interval 분기가 필요하면 테이블/파티션 구조에 따라 쿼리를 분기하도록 확장 가능
33+
# adjusted_close가 중요하다면 쿼리 단계에서 확실히 가져옵니다.
3734
query = text("""
3835
SELECT ticker, date, open, high, low, close, adjusted_close, volume
3936
FROM public.price_data
@@ -42,28 +39,32 @@ def fetch_ohlcv(
4239
ORDER BY date;
4340
""")
4441

45-
# 3) DB에서 읽기 (with 문으로 커넥션 자동 정리)
4642
with engine.connect() as conn:
4743
df = pd.read_sql(
4844
query,
49-
con=conn, # 꼭 키워드 인자로 con=conn
50-
params={"ticker": ticker, "start": start, "end": end}, # 튜플 X, 딕셔너리 O
51-
)
45+
con=conn,
46+
params={"ticker": ticker, "start": start, "end": end},
47+
)
5248

53-
# 4) 후처리: 컬럼 정렬 및 date 타입 통일
49+
# 빈 데이터 처리
5450
if df is None or df.empty:
55-
# 빈 DataFrame이면 일관된 컬럼 스키마로 반환
5651
return pd.DataFrame(columns=["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"])
5752

58-
# date 컬럼을 datetime으로 변경 (UTC로 맞추고 싶으면 pd.to_datetime(..., utc=True) 사용)
53+
# 날짜 변환
5954
if "date" in df.columns:
6055
df["date"] = pd.to_datetime(df["date"])
6156

62-
# 선택: 컬럼 순서 고정 (일관성 유지)
57+
# 데이터 보정 로직 추가
58+
# 1. adjusted_close가 없는 경우(NaN) -> close 값으로 대체 (결측치 방지)
59+
if "adjusted_close" in df.columns and "close" in df.columns:
60+
df["adjusted_close"] = df["adjusted_close"].fillna(df["close"])
61+
elif "adjusted_close" not in df.columns and "close" in df.columns:
62+
# 컬럼 자체가 없으면 close를 복사해서 생성
63+
df["adjusted_close"] = df["close"]
64+
65+
# 컬럼 순서 정리
6366
desired_cols = ["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"]
64-
# 존재하는 컬럼만 가져오기
6567
cols_present = [c for c in desired_cols if c in df.columns]
6668
df = df.loc[:, cols_present]
6769

68-
return df
69-
70+
return df

AI/modules/signal/core/features.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,78 @@
11
# AI/modules/signal/core/features.py
22
"""
3-
[피처 엔지니어링 모듈]
4-
- OHLCV 데이터를 입력받아 학습에 필요한 기술적 지표(RSI, MACD, 볼린저밴드 등)를 추가합니다.
5-
- 데이터 로더(DataLoader)에서 이 함수를 호출하여 전처리를 수행합니다.
3+
[피처 엔지니어링 모듈 - Adjusted Close 통합 버전]
4+
- 데이터에 'adjusted_close'가 있다면 이를 'close'로 덮어씌웁니다.
5+
- 이렇게 하면 모든 지표(RSI, MACD 등)가 자연스럽게 '조정 종가' 기준으로 계산됩니다.
6+
- 학습 시 'close'와 'adjusted_close'가 중복되는 문제도 해결됩니다.
67
"""
78

89
import pandas as pd
910
import numpy as np
1011

1112
def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
1213
"""
13-
데이터프레임에 기술적 지표 컬럼을 추가합니다.
14-
15-
Args:
16-
df (pd.DataFrame): OHLCV 데이터 (필수 컬럼: 'close', 'high', 'low', 'volume')
17-
18-
Returns:
19-
pd.DataFrame: 지표가 추가된 데이터프레임
14+
1. 조정 종가(Adjusted Close)를 종가(Close)로 통합합니다.
15+
2. 기술적 지표를 계산하여 추가합니다.
2016
"""
2117
if df.empty:
2218
return df
2319

2420
df = df.copy()
2521

22+
# ★ [핵심 수정] 조정 종가 우선 정책
23+
# adjusted_close가 있으면, 이를 close에 덮어쓰고 adjusted_close 컬럼은 삭제합니다.
24+
if 'adjusted_close' in df.columns:
25+
# 결측치 방지 (혹시 adjusted_close가 비어있으면 close 값 사용)
26+
df['adjusted_close'] = df['adjusted_close'].fillna(df['close'])
27+
28+
# 덮어쓰기
29+
df['close'] = df['adjusted_close']
30+
31+
# 중복 방지를 위해 삭제 (이제 close가 adjusted_close 역할을 함)
32+
df.drop(columns=['adjusted_close'], inplace=True)
33+
34+
# --- 이하 모든 계산은 'close'(실제로는 조정 종가)를 기준으로 수행됨 ---
35+
2636
# 1. 이동평균선 (Simple Moving Average)
2737
df['ma5'] = df['close'].rolling(window=5).mean()
2838
df['ma20'] = df['close'].rolling(window=20).mean()
2939
df['ma60'] = df['close'].rolling(window=60).mean()
3040

3141
# 2. RSI (Relative Strength Index)
32-
# CodeRabbit 리뷰 반영: 엣지 케이스(횡보, 상승지속) 정밀 처리
3342
delta = df['close'].diff()
3443
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
3544
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
3645

37-
# division by zero 등 경고 억제
3846
with np.errstate(divide='ignore', invalid='ignore'):
3947
rs = gain / loss
4048

41-
# 기본 RSI 계산 (loss=0인 경우 rs=inf가 되며, 100/(1+inf)=0 이므로 RSI=100이 됨 -> 정상)
4249
df['rsi'] = 100 - (100 / (1 + rs))
4350

44-
# [보정 1] 가격 변동이 아예 없는 경우 (Gain=0, Loss=0) -> NaN 발생 -> 50(중립)으로 설정
51+
# RSI 보정
4552
df.loc[(gain == 0) & (loss == 0), 'rsi'] = 50
46-
47-
# [보정 2] 하락 없이 상승만 한 경우 (Loss=0, Gain>0) -> 100(강세)으로 설정 (수식상 자동 처리되나 명시)
4853
df.loc[(loss == 0) & (gain > 0), 'rsi'] = 100
4954

5055
# 3. 볼린저 밴드 (Bollinger Bands)
5156
df['std20'] = df['close'].rolling(window=20).std()
5257
df['upper_band'] = df['ma20'] + (df['std20'] * 2)
5358
df['lower_band'] = df['ma20'] - (df['std20'] * 2)
5459

55-
# 4. MACD (Moving Average Convergence Divergence)
60+
# 4. MACD
5661
exp12 = df['close'].ewm(span=12, adjust=False).mean()
5762
exp26 = df['close'].ewm(span=26, adjust=False).mean()
5863
df['macd'] = exp12 - exp26
5964
df['signal_line'] = df['macd'].ewm(span=9, adjust=False).mean()
6065

6166
# 5. 거래량 변화율
62-
df['vol_change'] = df['volume'].pct_change()
67+
if 'volume' in df.columns:
68+
df['vol_change'] = df['volume'].pct_change()
69+
df['vol_change'] = df['vol_change'].replace([np.inf, -np.inf], 0)
70+
else:
71+
df['vol_change'] = 0
6372

64-
# 6. 결측치 처리 (지표 계산 초반 구간)
65-
# [수정] FutureWarning 해결: fillna(method='bfill') -> bfill()
73+
# === [데이터 정제] ===
74+
df.replace([np.inf, -np.inf], np.nan, inplace=True)
6675
df = df.bfill()
67-
df = df.fillna(0) # 앞부분 bfill로도 안 채워지는 경우 0 처리
76+
df = df.fillna(0)
6877

6978
return df
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
[Backtest Execution Package]
3+
- 단일 종목 및 포트폴리오 단위의 백테스트 실행 함수들을 제공합니다.
4+
"""
5+
6+
# 함수 이름이 겹치지 않게 alias(별칭)를 주어 명확히 구분합니다.
7+
from .run_portfolio import run_backtest as run_portfolio_backtest
8+
from .run_backtrader_single import run_single_backtest
9+
10+
__all__ = ['run_portfolio_backtest', 'run_single_backtest']
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# AI/modules/trader/backtest/run_backtrader_single.py
2+
"""
3+
[Backtrader 기반 단일 종목 정밀 백테스트]
4+
- Walk-Forward Validation 지원
5+
- strategies/rule_based.py 의 RuleBasedStrategy 클래스 사용
6+
- AI Score 시각화 기능 포함
7+
"""
8+
9+
import sys
10+
import os
11+
import backtrader as bt
12+
import pandas as pd
13+
import numpy as np
14+
15+
current_dir = os.path.dirname(os.path.abspath(__file__))
16+
project_root = os.path.abspath(os.path.join(current_dir, "../../../.."))
17+
if project_root not in sys.path:
18+
sys.path.append(project_root)
19+
20+
from AI.modules.signal.core.data_loader import SignalDataLoader
21+
from AI.modules.signal.models import get_model
22+
# ★ [수정] 클래스 기반 전략 불러오기
23+
from AI.modules.trader.strategies.rule_based import RuleBasedStrategy
24+
25+
class AIScoreObserver(bt.Observer):
26+
"""차트 하단에 AI 모델 점수를 그리기 위한 클래스"""
27+
lines = ('score', 'limit_buy', 'limit_sell')
28+
plotinfo = dict(plot=True, subplot=True, plotname='AI Probability')
29+
plotlines = dict(
30+
score=dict(marker='o', markersize=3.0, color='blue', _fill_gt=(0.5, 'red'), _fill_lt=(0.5, 'green')),
31+
limit_buy=dict(color='red', linestyle='--'),
32+
limit_sell=dict(color='green', linestyle='--')
33+
)
34+
35+
def next(self):
36+
score = getattr(self._owner, 'current_score', 0.5)
37+
self.lines.score[0] = score
38+
self.lines.limit_buy[0] = 0.65
39+
self.lines.limit_sell[0] = 0.40
40+
41+
class TransformerWalkForwardStrategy(bt.Strategy):
42+
params = (
43+
('model_weights_path', None),
44+
('raw_df', None),
45+
('features', None),
46+
('loader', None),
47+
('seq_len', 60),
48+
('model_name', "transformer"),
49+
)
50+
51+
def __init__(self):
52+
self.model = self._load_model()
53+
self.order = None
54+
self.current_score = 0.5
55+
# ★ [수정] 전략 객체 초기화
56+
self.strategy_logic = RuleBasedStrategy(buy_threshold=0.65, sell_threshold=0.40)
57+
58+
def log(self, txt, dt=None):
59+
dt = dt or self.datas[0].datetime.date(0)
60+
print(f'[{dt.isoformat()}] {txt}')
61+
62+
def _load_model(self):
63+
path = self.p.model_weights_path
64+
if not path or not os.path.exists(path):
65+
self.log("⚠️ 모델 가중치 파일 없음.")
66+
return None
67+
68+
default_config = {
69+
"head_size": 256, "num_heads": 4, "ff_dim": 4,
70+
"num_blocks": 4, "mlp_units": [128], "dropout": 0.1
71+
}
72+
try:
73+
model = get_model(self.p.model_name, default_config)
74+
model.build((None, self.p.seq_len, len(self.p.features)))
75+
if hasattr(model, 'model'):
76+
model.model.load_weights(path)
77+
else:
78+
model.load_weights(path)
79+
return model
80+
except Exception as e:
81+
self.log(f"⚠️ 모델 로드 에러: {e}")
82+
return None
83+
84+
def notify_order(self, order):
85+
if order.status in [order.Completed]:
86+
if order.isbuy():
87+
self.log(f"🔵 BUY 체결 @ {order.executed.price:,.0f}")
88+
elif order.issell():
89+
self.log(f"🔴 SELL 체결 @ {order.executed.price:,.0f}")
90+
self.order = None
91+
92+
def next(self):
93+
if len(self) < self.p.seq_len:
94+
return
95+
96+
current_date = self.datas[0].datetime.datetime(0)
97+
past_data = self.p.raw_df.loc[:current_date]
98+
if len(past_data) < self.p.seq_len:
99+
return
100+
101+
# 1. Walk-Forward Prediction
102+
self.p.loader.scaler.fit(past_data[self.p.features])
103+
recent_data = past_data.iloc[-self.p.seq_len:]
104+
input_seq = np.expand_dims(self.p.loader.scaler.transform(recent_data[self.p.features]), axis=0)
105+
106+
if self.model:
107+
pred = self.model.predict(input_seq, verbose=0)
108+
score = float(pred[0][0])
109+
else:
110+
score = 0.5
111+
112+
self.current_score = score
113+
114+
# 2. 매매 판단 (RuleBasedStrategy 사용)
115+
if self.order: return # 이미 주문 중이면 패스
116+
117+
position_qty = self.position.size
118+
# ★ [수정] 클래스 메서드 호출로 변경 (코드가 훨씬 깔끔해짐)
119+
decision = self.strategy_logic.get_action(score, position_qty)
120+
121+
if decision['type'] == 'BUY':
122+
# 보유 현금의 95%만큼 매수 계산 (Backtrader 로직)
123+
cash = self.broker.get_cash()
124+
price = self.datas[0].close[0]
125+
# 수수료 고려하여 안전하게 계산
126+
size = int((cash * 0.95) / price)
127+
if size > 0:
128+
self.log(f"BUY 신호 (Score: {score:.2f})")
129+
self.order = self.buy(size=size)
130+
131+
elif decision['type'] == 'SELL':
132+
if position_qty > 0:
133+
self.log(f"SELL 신호 (Score: {score:.2f})")
134+
self.order = self.close() # 전량 청산
135+
136+
def run_single_backtest(ticker="AAPL", start_date="2024-01-01", end_date="2024-06-01", enable_plot=True):
137+
print(f"\n=== [{ticker}] 단일 종목 백테스트 시작 ===")
138+
139+
weight_path = os.path.join(project_root, "AI/data/weights/transformer/universal_transformer.keras")
140+
loader = SignalDataLoader(sequence_length=60)
141+
df = loader.load_data(ticker, start_date, end_date)
142+
143+
if df is None or len(df) < 100:
144+
print("❌ 데이터 로드 실패")
145+
return
146+
147+
if 'date' in df.columns:
148+
df['date'] = pd.to_datetime(df['date'])
149+
df.set_index('date', inplace=True)
150+
151+
data_feed = bt.feeds.PandasData(dataname=df)
152+
features = df.select_dtypes(include=[np.number]).columns.tolist()
153+
154+
cerebro = bt.Cerebro()
155+
cerebro.adddata(data_feed)
156+
157+
cerebro.addstrategy(
158+
TransformerWalkForwardStrategy,
159+
model_weights_path=weight_path,
160+
raw_df=df,
161+
features=features,
162+
loader=loader
163+
)
164+
165+
if enable_plot:
166+
cerebro.addobserver(AIScoreObserver)
167+
168+
cerebro.broker.setcash(10_000_000)
169+
cerebro.broker.setcommission(commission=0.0015)
170+
cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.0)
171+
cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
172+
173+
print(f"💰 초기 자산: {cerebro.broker.getvalue():,.0f}원")
174+
results = cerebro.run()
175+
176+
strat = results[0]
177+
final_val = cerebro.broker.getvalue()
178+
mdd = strat.analyzers.drawdown.get_analysis()['max']['drawdown']
179+
sharpe = strat.analyzers.sharpe.get_analysis().get('sharperatio', 0.0)
180+
181+
print(f"💰 최종 자산: {final_val:,.0f}원 ({(final_val/10000000 - 1)*100:.2f}%)")
182+
print(f"📉 MDD: {mdd:.2f}% | 📊 Sharpe: {sharpe:.4f}")
183+
184+
if enable_plot:
185+
cerebro.plot(style='candlestick', volume=False)
186+
187+
if __name__ == "__main__":
188+
run_single_backtest()

0 commit comments

Comments
 (0)