Skip to content

Commit ae48e58

Browse files
authored
[AI] [FIX] infer_input_shape_from_legacy_h5() 메서드 추가 (#333)
* [AI] [FIX] infer_input_shape_from_legacy_h5() 메서드 추가 * [AI] [FIX] 매매로직 치명적 버그 수정 * [AI] [FEAT] 백테스트 코드 작업 및 로컬 LLM 모델 변경 * [AI] [REFACT] 데이터 정리/검증 동작 강화
1 parent 490059e commit ae48e58

19 files changed

Lines changed: 1649 additions & 424 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ __pycache__/
3131
AI/.venv/
3232
AI/data/weights/tcn/
3333
AI/config/trading.local.json
34+
AI/tests/out/
3435

3536
# ===== Backend =====
3637
backend/src/main/java/org/sejongisc/backend/stock/TestController.java

AI/config/watchlist.json

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,33 @@
22
"tickers": [
33
"NVDA",
44
"TSLA",
5-
"MU",
6-
"SNDK",
7-
"MSFT",
8-
"AVGO",
95
"AAPL",
10-
"AMZN",
6+
"PLTR",
7+
"MSFT",
118
"META",
9+
"AMZN",
1210
"GOOGL",
13-
"PLTR",
14-
"AMD",
15-
"ORCL",
11+
"AVGO",
12+
"SMCI",
13+
"APP",
1614
"GOOG",
17-
"INTC",
18-
"XOM",
1915
"NFLX",
20-
"CVX",
21-
"JPM",
22-
"WDC",
23-
"CRM",
16+
"AMD",
17+
"UNH",
2418
"LLY",
25-
"WMT",
26-
"LRCX",
27-
"APP",
19+
"CRM",
20+
"BRK-B",
21+
"V",
2822
"COIN",
23+
"JPM",
24+
"WMT",
25+
"INTC",
26+
"HOOD",
27+
"COST",
28+
"MU",
2929
"BAC",
30-
"ADBE",
31-
"BRK-B",
32-
"V"
30+
"BKNG",
31+
"GEV",
32+
"HD"
3333
]
3434
}

AI/libs/database/repository.py

Lines changed: 174 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
- 기존의 개별 함수들을 PortfolioRepository 클래스로 통합하여 응집도를 높이고 DB 커넥션 관리를 최적화했습니다.
66
"""
77

8+
import os
89
from typing import List, Tuple, Optional, Any, Dict
910
import pandas as pd
1011
import psycopg2
@@ -108,7 +109,8 @@ def get_current_position(self, ticker: str, target_date: str = None, initial_cas
108109
params.append(target_date)
109110

110111
# 시간순으로 정렬하여 정확한 롤링 계산을 수행합니다.
111-
query += " ORDER BY fill_date ASC, created_at ASC"
112+
# Deterministic ordering is important for reproducible position/cost accounting.
113+
query += " ORDER BY fill_date ASC, created_at ASC, id ASC"
112114

113115
try:
114116
cursor.execute(query, tuple(params))
@@ -171,42 +173,138 @@ def get_current_position(self, ticker: str, target_date: str = None, initial_cas
171173

172174
def get_current_cash(self, target_date: str = None, initial_cash: float = 10000000) -> float:
173175
"""
174-
[현재 포트폴리오 현금 조회]
175-
포트폴리오 요약 테이블에서 target_date 이전의 가장 최근 현금 잔고를 조회합니다.
176-
177-
Args:
178-
target_date (str, optional): 기준 날짜
179-
initial_cash (float): 내역이 없을 경우 반환할 초기 자본금
180-
181-
Returns:
182-
float: 계산된 현재 현금 잔고
176+
executions 테이블의 cash_after를 우선 사용해 기준 시점 현금을 조회합니다.
177+
데이터가 없거나 조회에 실패하면 보수적으로 초기 현금(initial_cash)을 반환합니다.
183178
"""
184179
conn = self._get_connection()
185180
if conn is None:
186-
return initial_cash
187-
188-
cursor = conn.cursor()
189-
190-
# 기준일 이전의 가장 최근 마감 데이터를 가져오는 서브쿼리 활용
181+
return float(initial_cash)
182+
183+
try:
184+
with conn.cursor() as cursor:
185+
if target_date:
186+
exec_cash_query = """
187+
SELECT cash_after
188+
FROM public.executions
189+
WHERE fill_date <= %s
190+
ORDER BY fill_date DESC, created_at DESC, id DESC
191+
LIMIT 1
192+
"""
193+
cursor.execute(exec_cash_query, (target_date,))
194+
exec_cash = cursor.fetchone()
195+
if exec_cash and exec_cash[0] is not None:
196+
return float(exec_cash[0])
197+
198+
summary_cash_query = """
199+
SELECT cash
200+
FROM public.portfolio_summary
201+
WHERE date = (
202+
SELECT MAX(date)
203+
FROM public.portfolio_summary
204+
WHERE date < %s
205+
)
206+
LIMIT 1
207+
"""
208+
cursor.execute(summary_cash_query, (target_date,))
209+
summary_cash = cursor.fetchone()
210+
if summary_cash and summary_cash[0] is not None:
211+
return float(summary_cash[0])
212+
return float(initial_cash)
213+
except Exception as e:
214+
print(f"[PortfolioRepository][Error] 현재 현금 조회 실패: {e}")
215+
return float(initial_cash)
216+
finally:
217+
conn.close()
218+
219+
def get_open_tickers(self, target_date: str) -> List[str]:
220+
"""
221+
target_date 이전까지의 누적 순수량(net 포지션)이 0보다 큰 티커 목록을 반환합니다.
222+
"""
223+
conn = self._get_connection()
224+
if conn is None:
225+
return []
226+
191227
query = """
192-
SELECT cash
193-
FROM public.portfolio_summary
194-
WHERE date = (SELECT MAX(date) FROM public.portfolio_summary WHERE date < %s)
195-
LIMIT 1;
228+
SELECT ticker
229+
FROM public.executions
230+
WHERE fill_date <= %s
231+
GROUP BY ticker
232+
HAVING SUM(
233+
CASE
234+
WHEN side = 'BUY' THEN qty
235+
WHEN side = 'SELL' THEN -qty
236+
ELSE 0
237+
END
238+
) > 0
239+
ORDER BY ticker
196240
"""
197-
241+
242+
try:
243+
with conn.cursor() as cursor:
244+
cursor.execute(query, (target_date,))
245+
rows = cursor.fetchall()
246+
return [str(row[0]) for row in rows]
247+
except Exception as e:
248+
print(f"[PortfolioRepository][Error] Open 티커 목록 조회 실패: {e}")
249+
return []
250+
finally:
251+
conn.close()
252+
253+
def reset_run_data(self, run_id: str, target_date: Optional[str] = None) -> None:
254+
"""
255+
Remove stale simulation artifacts before a rerun.
256+
257+
If target_date is provided, downstream dates are also removed so cash/position
258+
chains are recalculated consistently from target_date forward.
259+
"""
260+
if not run_id:
261+
return
262+
263+
conn = self._get_connection()
264+
if conn is None:
265+
return
266+
198267
try:
199-
cursor.execute(query, (target_date,))
200-
result = cursor.fetchone()
201-
if result:
202-
return float(result[0])
268+
with conn.cursor() as cursor:
269+
if target_date:
270+
# Safety-first default: only clear current run artifacts.
271+
# Global chain reset can remove unrelated simulations sharing the same DB.
272+
allow_global_chain_reset = os.environ.get("AI_ALLOW_GLOBAL_CHAIN_RESET", "0") == "1"
273+
if allow_global_chain_reset:
274+
cursor.execute(
275+
"DELETE FROM public.executions WHERE fill_date >= %s AND run_id LIKE 'daily_%%'",
276+
(target_date,),
277+
)
278+
cursor.execute(
279+
"DELETE FROM public.xai_reports WHERE date >= %s AND run_id LIKE 'daily_%%'",
280+
(target_date,),
281+
)
282+
cursor.execute("DELETE FROM public.portfolio_positions WHERE date >= %s", (target_date,))
283+
cursor.execute("DELETE FROM public.portfolio_summary WHERE date >= %s", (target_date,))
284+
else:
285+
cursor.execute("DELETE FROM public.executions WHERE run_id = %s", (run_id,))
286+
cursor.execute("DELETE FROM public.xai_reports WHERE run_id = %s", (run_id,))
287+
else:
288+
cursor.execute("DELETE FROM public.executions WHERE run_id = %s", (run_id,))
289+
cursor.execute("DELETE FROM public.xai_reports WHERE run_id = %s", (run_id,))
290+
conn.commit()
291+
if target_date:
292+
if os.environ.get("AI_ALLOW_GLOBAL_CHAIN_RESET", "0") == "1":
293+
print(
294+
f"[PortfolioRepository] Reset simulation rows from {target_date} onward "
295+
f"(triggered by run_id={run_id})."
296+
)
297+
else:
298+
print(
299+
f"[PortfolioRepository] Reset current run artifacts only "
300+
f"(run_id={run_id}, target_date={target_date})."
301+
)
203302
else:
204-
return initial_cash
303+
print(f"[PortfolioRepository] Reset existing records for run_id={run_id}.")
205304
except Exception as e:
206-
print(f"[PortfolioRepository][Error] 포트폴리오 현금 조회 중 오류 발생: {e}")
207-
return initial_cash
305+
conn.rollback()
306+
print(f"[PortfolioRepository][Error] reset_run_data failed: {e}")
208307
finally:
209-
cursor.close()
210308
conn.close()
211309

212310
def save_executions_to_db(self, fills_df: pd.DataFrame) -> None:
@@ -231,9 +329,31 @@ def save_executions_to_db(self, fills_df: pd.DataFrame) -> None:
231329

232330
if not required_cols.issubset(fills_df.columns):
233331
missing = required_cols - set(fills_df.columns)
234-
print(f"[PortfolioRepository][Error] 체결 내역 데이터에 필수 컬럼이 누락되었습니다: {missing}")
332+
print(f"[PortfolioRepository][Error] Missing required execution columns: {missing}")
235333
return
236334

335+
def _normalize_run_id(value: Any) -> Optional[str]:
336+
if pd.isna(value):
337+
return None
338+
normalized = str(value).strip()
339+
if not normalized:
340+
return None
341+
if normalized.lower() in {"nan", "none", "null"}:
342+
return None
343+
return normalized
344+
345+
normalized_run_ids = fills_df["run_id"].apply(_normalize_run_id)
346+
missing_mask = normalized_run_ids.isna()
347+
if bool(missing_mask.any()):
348+
sample_rows = fills_df.loc[missing_mask, ["ticker", "signal_date", "fill_date"]].head(5)
349+
raise ValueError(
350+
"[PortfolioRepository][Error] run_id must be non-empty for all execution rows. "
351+
f"missing rows sample={sample_rows.to_dict(orient='records')}"
352+
)
353+
354+
fills_df = fills_df.copy()
355+
fills_df["run_id"] = normalized_run_ids
356+
237357
conn = self._get_connection()
238358
if conn is None:
239359
print("[PortfolioRepository][Error] DB 연결에 실패하여 체결 내역을 저장할 수 없습니다.")
@@ -242,6 +362,15 @@ def save_executions_to_db(self, fills_df: pd.DataFrame) -> None:
242362
cursor = conn.cursor()
243363

244364
try:
365+
run_ids = sorted(
366+
{
367+
str(run_id).strip()
368+
for run_id in fills_df["run_id"].tolist()
369+
if pd.notna(run_id) and str(run_id).strip()
370+
}
371+
)
372+
if run_ids:
373+
cursor.execute("DELETE FROM public.executions WHERE run_id = ANY(%s)", (run_ids,))
245374
# 다량의 데이터를 빠르게 넣기 위한 INSERT 구문 준비
246375
insert_query = """
247376
INSERT INTO public.executions (
@@ -296,7 +425,7 @@ def save_executions_to_db(self, fills_df: pd.DataFrame) -> None:
296425
cursor.close()
297426
conn.close()
298427

299-
def save_reports_to_db(self, reports_tuple_list: list) -> list:
428+
def save_reports_to_db(self, reports_tuple_list: list, run_id: Optional[str] = None) -> list:
300429
"""
301430
[XAI 리포트 일괄 저장]
302431
설명 가능한 AI(XAI) 분석 리포트들을 DB에 저장하고, 생성된 Primary Key(ID) 리스트를 반환합니다.
@@ -320,16 +449,26 @@ def save_reports_to_db(self, reports_tuple_list: list) -> list:
320449
# RETURNING id 절을 사용하여 INSERT 후 자동 생성된 PK를 반환받습니다.
321450
insert_query = """
322451
INSERT INTO public.xai_reports (
323-
ticker, signal, price, date, report
452+
ticker, signal, price, date, report, run_id
324453
) VALUES %s
454+
ON CONFLICT (ticker, date, signal) DO UPDATE
455+
SET price = EXCLUDED.price,
456+
report = EXCLUDED.report,
457+
run_id = EXCLUDED.run_id,
458+
created_at = NOW()
325459
RETURNING id
326460
"""
327461

328462
# fetch=True 옵션으로 execute_values 실행 시 RETURNING 결과를 리스트 형태로 모아줍니다.
463+
reports_with_run = []
464+
for row in reports_tuple_list:
465+
ticker, signal, price, date, report = row
466+
reports_with_run.append((ticker, signal, price, date, report, run_id))
467+
329468
result_ids = execute_values(
330-
cursor,
331-
insert_query,
332-
reports_tuple_list,
469+
cursor,
470+
insert_query,
471+
reports_with_run,
333472
fetch=True
334473
)
335474

@@ -419,4 +558,4 @@ def save_portfolio_positions(self, date: str, data_tuples: list):
419558
print(f"[PortfolioRepository][Error] 포지션 저장 실패: {e}")
420559
finally:
421560
cursor.close()
422-
conn.close()
561+
conn.close()

AI/libs/llm/base_client.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
11
# AI/libs/llm/base_client.py
22
"""
3-
[LLM 클라이언트 인터페이스]
4-
- 모든 LLM 서비스(Groq, Ollama 등)가 준수해야 할 공통 규약을 정의합니다.
5-
- 이를 통해 모델 교체 시 코드 수정을 최소화할 수 있습니다.
3+
Shared base interface for all LLM clients.
64
"""
75

86
from abc import ABC, abstractmethod
9-
from typing import Dict, Any, Optional
7+
from typing import Any, Optional
8+
109

1110
class BaseLLMClient(ABC):
12-
"""모든 LLM 클라이언트의 추상 기본 클래스"""
11+
"""Abstract base class for all LLM clients."""
1312

1413
def __init__(self, api_key: Optional[str] = None, model_name: str = "default"):
1514
self.api_key = api_key
1615
self.model_name = model_name
16+
self.last_error: Optional[str] = None
17+
18+
def clear_last_error(self) -> None:
19+
self.last_error = None
20+
21+
def set_last_error(self, error: Any) -> None:
22+
if isinstance(error, Exception):
23+
self.last_error = str(error) or error.__class__.__name__
24+
return
25+
self.last_error = str(error)
1726

1827
@abstractmethod
1928
def generate_text(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> str:
20-
"""
21-
단일 프롬프트를 입력받아 텍스트 응답을 생성합니다.
22-
23-
Args:
24-
prompt (str): 사용자 입력 프롬프트
25-
system_prompt (str, optional): 시스템 프롬프트 (역할 정의 등)
26-
**kwargs: 모델별 추가 파라미터 (temperature 등)
27-
28-
Returns:
29-
str: 생성된 텍스트
30-
"""
31-
pass
32-
29+
"""Generate text from prompt."""
30+
raise NotImplementedError
31+
3332
@abstractmethod
3433
def get_health(self) -> bool:
35-
"""서비스 상태를 확인합니다."""
36-
pass
34+
"""Return whether the LLM service is available."""
35+
raise NotImplementedError

0 commit comments

Comments
 (0)