Skip to content

Commit d6df3de

Browse files
authored
Merge pull request #46 from SynergyX-AI-Pattern/refactor/#44_pattern_detection_accuracy
Refactor/#44 pattern detection accuracy
2 parents 947df0a + 3fbd984 commit d6df3de

File tree

1 file changed

+74
-13
lines changed

1 file changed

+74
-13
lines changed

app/services/pattern_detection_service.py

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import logging
2+
import pandas as pd
23
from datetime import datetime, timedelta, timezone
34
from typing import Optional, Tuple
4-
55
from sqlalchemy.orm import Session
6-
76
from app.api_payload.code.error_status import ErrorStatus
87
from app.exceptions.base import APIException
98
from app.crud.stock_timeseries import get_stock_timeseries_by_unit
@@ -33,16 +32,23 @@ def detect(db: Session) -> list[dict]:
3332

3433
# 감지 대상 패턴 불러오기
3534
applies = get_applies(db)
35+
if not applies:
36+
logger.info("[PatternDetection] 감지 대상 없음")
37+
return []
3638

3739
# 감지 성공한 패턴
3840
success_applies = []
3941

4042
for apply in applies:
4143
# 각 패턴-종목에 대해 감지 수행
42-
result = PatternDetectionService._process_apply(apply, db, now)
43-
if result:
44-
success_applies.append(apply)
45-
results.append(result)
44+
try:
45+
result = PatternDetectionService._process_apply(apply, db, now)
46+
if result:
47+
success_applies.append(apply)
48+
results.append(result)
49+
except Exception as e:
50+
logger.warning(f"[PatternDetection] {apply.stock.name} 감지 중 오류 발생: {e}")
51+
continue
4652

4753
# 감지 성공한 패턴은 알림 설정 해제
4854
for apply in success_applies:
@@ -88,7 +94,8 @@ def _process_apply(apply, db: Session, now: datetime) -> Optional[dict]:
8894
now=now
8995
)
9096

91-
if not closes or not timestamps:
97+
# 데이터 유효성 검증
98+
if not closes or not timestamps or len(closes) < len(pattern) * 2:
9299
return None
93100

94101
# DTW 매칭
@@ -97,19 +104,26 @@ def _process_apply(apply, db: Session, now: datetime) -> Optional[dict]:
97104
except APIException as e:
98105
# 데이터 부족 시 감지 생략
99106
if e.status == ErrorStatus.NOT_ENOUGH_DATA:
100-
raise APIException(ErrorStatus.NOT_ENOUGH_DATA) from e
101-
raise
107+
return None
108+
return None
102109

103-
# 매칭된 구간 없으면 감지 안 함
110+
# 방향성 검증
111+
idxes = [
112+
i for i in idxes
113+
if PatternDetectionService._same_direction(pattern, closes, i)
114+
]
115+
116+
# 매칭된 구간 없을 시 감지 생략
104117
if not idxes:
105118
return None
106119

107120
# 감지 시점 종가
108121
current_price = closes[-1]
109122

110123
# 수익률 계산
111-
rate_of_return = ((current_price - entry_price) / entry_price) * 100
124+
rate_of_return = round(((current_price - entry_price) / entry_price) * 100, 2)
112125

126+
# 최소 수익률 조건 미충족 시 감지 생략
113127
if min_valid_return is not None and rate_of_return < min_valid_return:
114128
return None
115129

@@ -153,6 +167,25 @@ def _process_apply(apply, db: Session, now: datetime) -> Optional[dict]:
153167
"value": value
154168
}
155169

170+
@staticmethod
171+
def _same_direction(
172+
pattern: list[float],
173+
closes: list[float],
174+
idx: int
175+
) -> bool:
176+
"""
177+
패턴과 실제 주가의 방향 (상승, 하락)이 일치하는지 검증합니다.
178+
"""
179+
180+
# 패턴의 방향 (기울기)
181+
pat_slope = pattern[-1] - pattern[0]
182+
183+
# 주가의 방향 (기울기)
184+
seg_slope = closes[idx + len(pattern) - 1] - closes[idx]
185+
186+
# 부호가 동일하면 동일 방향
187+
return (pat_slope * seg_slope) > 0
188+
156189
@staticmethod
157190
def _load_price_data(
158191
db: Session,
@@ -162,7 +195,8 @@ def _load_price_data(
162195
now: datetime
163196
) -> Tuple[list[float], list[datetime]]:
164197
"""
165-
진입 시점(entry_at)부터 현재까지의 가격 데이터를 조회합니다.
198+
진입 시점(entry_at)부터 현재까지의 가격 데이터를 조회하고,
199+
노이즈를 제거합니다.
166200
167201
Returns:
168202
- 종가 리스트 (closes)
@@ -180,4 +214,31 @@ def _load_price_data(
180214
raise APIException(ErrorStatus.STOCK_OHLCV_NOT_FOUND)
181215

182216
timestamps, closes = zip(*rows)
183-
return list(closes), list(timestamps)
217+
218+
# 데이터 충분 시 smoothing 기법 사용
219+
if len(closes) > 3:
220+
closes = PatternDetectionService._smooth_series(closes, window=3)
221+
return list(closes), list(timestamps)
222+
223+
@staticmethod
224+
def _smooth_series(
225+
closes: list[float],
226+
window: int = 3
227+
) -> list[float]:
228+
"""
229+
이동 평균선을 사용하여 노이즈를 제거합니다.
230+
231+
Parameters:
232+
closes: 종가 리스트
233+
window: 구간 길이 (기본값: 3)
234+
235+
Returns:
236+
노이즈가 제거된 종가 리스트
237+
"""
238+
239+
# 기존 종가 리스트 변환
240+
series = pd.Series(closes)
241+
242+
# 이동 평균 계산 및 보정
243+
smoothed = series.rolling(window=window, center=True).mean().bfill().ffill()
244+
return smoothed.tolist()

0 commit comments

Comments
 (0)