Skip to content

Commit 9db895f

Browse files
authored
Merge pull request #42 from SynergyX-AI-Pattern/refactor/#41_predict_batch_avg
[REFACTOR] 예측 배치에 top 20 계산 및 저장 로직 추가 #41
2 parents 22dc93b + 1b86658 commit 9db895f

File tree

12 files changed

+224
-9
lines changed

12 files changed

+224
-9
lines changed

app/core/logging_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ def setup_logging():
1010
level=LOG_LEVEL,
1111
format=LOG_FORMAT,
1212
stream=sys.stdout,
13+
force=True,
1314
)

app/db/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1-
from sqlalchemy.orm import declarative_base
1+
from app.db.base_class import Base
22

3-
Base = declarative_base()
3+
from app.models.stock import Stock
4+
from app.models.stock_detail import StockDetail
5+
from app.models.prediction import Prediction
6+
from app.models.pattern import Pattern
7+
from app.models.pattern_apply import PatternApply
8+
from app.models.pattern_detection_log import PatternDetectionLog

app/db/base_class.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from sqlalchemy.orm import declarative_base
2+
3+
Base = declarative_base()

app/db/session.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from sqlalchemy import create_engine
22
from sqlalchemy.orm import sessionmaker
33
from app.core.config import settings
4+
from app.db.base import Base
45

6+
7+
# DB 엔진 생성
58
engine = create_engine(
69
settings.DATABASE_URL,
7-
pool_pre_ping=True, # 연결 확인 (죽은 연결 자동 제거)
8-
echo=False # SQL 로그 사용
10+
pool_pre_ping=True,
11+
echo=False,
912
)
1013

14+
# 세션팩토리
1115
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
1216

1317

18+
# 의존성 주입용 (FastAPI DI)
1419
def get_db():
1520
db = SessionLocal()
1621
try:

app/models/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy import Column, DateTime
22
from sqlalchemy.sql import func
3-
from app.db.base import Base
3+
from app.db.base_class import Base
44

55

66
class BaseTimeModel(Base):

app/models/stock.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ class Stock(BaseTimeModel):
1414
# 관계 설정
1515
predictions = relationship("Prediction", back_populates="stock")
1616
pattern_applies = relationship("PatternApply", back_populates="stock")
17+
stock_detail = relationship("StockDetail", back_populates="stock", uselist=False)

app/models/stock_detail.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sqlalchemy import Column, BigInteger, Float, JSON, ForeignKey
2+
from app.models.base_model import BaseTimeModel
3+
from sqlalchemy.orm import relationship, Mapped
4+
5+
class StockDetail(BaseTimeModel):
6+
__tablename__ = "stock_detail"
7+
8+
id = Column(BigInteger, primary_key=True, autoincrement=True)
9+
stock_id = Column(BigInteger, ForeignKey("stock.id"), nullable=False, unique=True)
10+
11+
price = Column(Float, nullable=True)
12+
change_rate = Column(Float, nullable=True)
13+
financial_data = Column(JSON, nullable=True)
14+
change_amount = Column(Float, nullable=True)
15+
version = Column(BigInteger, nullable=True)
16+
17+
ai_avg_increase = Column(Float, nullable=True, default=None)
18+
ai_rank = Column(BigInteger, nullable=True, default=None)
19+
20+
# 관계
21+
stock = relationship("Stock", back_populates="stock_detail")

app/schedulers/predict_batch_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def run_batch():
3535
cooldown_sec=cooldown_sec
3636
)
3737

38+
# 평균 상승률, 랭크 계산
39+
top3_info = BatchService.update_ai_avg_increase_and_rank()
40+
3841
duration = time.time() - start_time
3942
logger.info(
4043
f"[Scheduler] 배치 예측 완료 - 성공: {success_count}, 실패: {fail_count}, 소요: {duration:.2f}s"
@@ -46,7 +49,7 @@ def run_batch():
4649
import threading
4750
def run_notification():
4851
asyncio.run(
49-
notify_discord_async(success_count, fail_count, duration, failed_symbols, 1404342496221462539))
52+
notify_discord_async(success_count, fail_count, duration, failed_symbols, 1404342496221462539, top3_info=top3_info,))
5053

5154
notification_thread = threading.Thread(target=run_notification)
5255
notification_thread.start()

app/services/batch_service.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from concurrent.futures import ThreadPoolExecutor, as_completed
2-
from sqlalchemy.orm import Session
2+
from sqlalchemy.orm import Session, aliased
3+
from sqlalchemy import func
34
from app.db.session import SessionLocal
45
from app.models.stock import Stock
6+
from app.models.stock_detail import StockDetail
7+
from app.models.prediction import Prediction
58
import time
9+
import datetime
610
from app.crud.prediction import create_prediction_objects, save_predictions
711
from app.services.prediction_service import PredictionService
812
import logging
@@ -122,3 +126,126 @@ def batch_predict_and_save(start_id: int, end_id: int, max_workers: int = 2):
122126
logger.exception(f"[{stock_id}] 처리 중 예외 발생: {str(e)}")
123127

124128
logger.info("전체 배치 완료")
129+
130+
@staticmethod
131+
def update_ai_avg_increase_and_rank(base_date: datetime.date | None = None):
132+
"""
133+
종목별 예측값 기반 실제 존재하는 예측일 기준 15일 평균 상승률 계산 후 stock_detail 업데이트
134+
:param base_date: 기준일 (None이면 오늘 날짜 사용)
135+
:return: None
136+
"""
137+
db = SessionLocal()
138+
try:
139+
today = base_date or datetime.date.today()
140+
logger.info(f"[BatchService] 평균 상승률 계산 시작 (기준일: {today})")
141+
142+
# --- 예측 데이터 존재 여부 확인 ---
143+
total_preds = db.query(Prediction).count()
144+
future_preds = db.query(Prediction).filter(Prediction.target_date > today).count()
145+
logger.info(f"[BatchService] Prediction 전체: {total_preds}건 / 기준일 이후: {future_preds}건")
146+
147+
if future_preds == 0:
148+
logger.warning(f"[BatchService] 기준일({today}) 이후 예측 데이터가 없습니다.")
149+
return
150+
151+
# --- 첫 번째 예측일 구하기 ---
152+
first_date_subq = (
153+
db.query(
154+
Prediction.stock_id,
155+
func.min(Prediction.target_date).label("first_target_date")
156+
)
157+
.filter(Prediction.target_date > today)
158+
.group_by(Prediction.stock_id)
159+
.subquery()
160+
)
161+
logger.debug(f"[BatchService] first_date_subq 생성 완료")
162+
163+
# --- 첫 예측일의 예측 종가 구하기 ---
164+
subquery_first_price = (
165+
db.query(
166+
Prediction.stock_id,
167+
Prediction.predicted_close.label("first_predicted_close")
168+
)
169+
.join(
170+
first_date_subq,
171+
(Prediction.stock_id == first_date_subq.c.stock_id)
172+
& (Prediction.target_date == first_date_subq.c.first_target_date)
173+
)
174+
.subquery()
175+
)
176+
first_price_count = db.query(subquery_first_price).count()
177+
logger.info(f"[BatchService] 첫 예측일 종가 매핑 완료 ({first_price_count}건)")
178+
179+
# --- 평균 상승률 계산 ---
180+
results = (
181+
db.query(
182+
Prediction.stock_id,
183+
func.avg(
184+
(Prediction.predicted_close - subquery_first_price.c.first_predicted_close)
185+
/ func.nullif(subquery_first_price.c.first_predicted_close, 0.0)
186+
).label("avg_increase")
187+
)
188+
.join(subquery_first_price, Prediction.stock_id == subquery_first_price.c.stock_id)
189+
.filter(Prediction.target_date > today)
190+
.group_by(Prediction.stock_id)
191+
.all()
192+
)
193+
194+
logger.info(f"[BatchService] 평균 상승률 계산 결과: {len(results)}건")
195+
196+
if not results:
197+
logger.warning("[BatchService] 평균 상승률 계산 결과 없음 (JOIN 또는 데이터 매칭 문제 가능)")
198+
return
199+
200+
# --- 상위 3개 샘플 출력 ---
201+
sample_logs = [
202+
f"stock_id={r.stock_id}, avg_increase={round((r.avg_increase or 0) * 100, 2)}%"
203+
for r in results[:3]
204+
]
205+
logger.debug(f"[BatchService] 계산 결과 샘플: {sample_logs}")
206+
207+
# --- 평균 상승률 내림차순 정렬 및 랭킹 부여 ---
208+
sorted_results = sorted(results, key=lambda r: r.avg_increase or 0, reverse=True)
209+
210+
for rank, row in enumerate(sorted_results, start=1):
211+
db.query(StockDetail).filter(StockDetail.stock_id == row.stock_id).update(
212+
{
213+
StockDetail.ai_avg_increase: row.avg_increase,
214+
StockDetail.ai_rank: rank,
215+
StockDetail.updated_at: datetime.datetime.utcnow(),
216+
}
217+
)
218+
if rank <= 3: # 상위 3개만 출력
219+
logger.debug(
220+
f"[BatchService] UPDATE → stock_id={row.stock_id}, "
221+
f"avg_increase={row.avg_increase}, rank={rank}"
222+
)
223+
224+
db.commit()
225+
logger.info(f"[BatchService] 평균 상승률 및 랭킹 갱신 완료 ({len(sorted_results)}개 종목)")
226+
227+
# 상위 3개 종목 조회
228+
top3 = (
229+
db.query(Stock.name, StockDetail.ai_avg_increase)
230+
.join(StockDetail, Stock.id == StockDetail.stock_id)
231+
.order_by(StockDetail.ai_rank.asc())
232+
.limit(3)
233+
.all()
234+
)
235+
236+
# 디스코드 알림용 리스트 반환
237+
top3_info = [
238+
{"name": name, "increase": round((increase or 0) * 100, 2)}
239+
for name, increase in top3
240+
]
241+
242+
return top3_info
243+
244+
except Exception as e:
245+
db.rollback()
246+
logger.exception(f"[BatchService] 평균 상승률 계산 중 오류 발생: {e}")
247+
raise
248+
249+
finally:
250+
db.close()
251+
logger.debug("[BatchService] 세션 종료 완료")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
import datetime
3+
from dotenv import load_dotenv
4+
import logging
5+
6+
# --- FastAPI와 동일한 로깅 설정 불러오기 ---
7+
from app.core.logging_config import setup_logging
8+
setup_logging()
9+
10+
# --- 환경 변수 로드 ---
11+
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
12+
dotenv_path = os.path.join(BASE_DIR, ".env")
13+
load_dotenv(dotenv_path)
14+
15+
from app.db.session import SessionLocal
16+
from app.services.batch_service import BatchService
17+
18+
19+
def test_update_ai_avg_increase_and_rank_real_db():
20+
"""
21+
[통합테스트] 실제 DB 기준으로 평균 상승률 계산 및 stock_detail 업데이트 테스트
22+
"""
23+
db = SessionLocal()
24+
25+
# 테스트용 기준일 (직접 지정 가능)
26+
base_date = datetime.date(2025, 10, 13)
27+
28+
print(f"[TEST] 기준일: {base_date}")
29+
30+
try:
31+
BatchService.update_ai_avg_increase_and_rank(base_date)
32+
db.commit()
33+
print("[SUCCESS] 평균 상승률 갱신 완료")
34+
35+
except Exception as e:
36+
db.rollback()
37+
print(f"[ERROR] 테스트 중 예외 발생: {e}")
38+
raise
39+
40+
finally:
41+
db.close()

0 commit comments

Comments
 (0)