Skip to content

Commit c19bf4e

Browse files
authored
Merge pull request #21 from SynergyX-AI-Pattern/feat/#16_stock_search_by_image
2 parents d1d932a + ae057c8 commit c19bf4e

File tree

16 files changed

+340
-3
lines changed

16 files changed

+340
-3
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,4 +361,7 @@ $RECYCLE.BIN/
361361
# Windows shortcuts
362362
*.lnk
363363

364+
# API Keys
365+
keys/
366+
364367
# End of https://www.toptal.com/developers/gitignore/api/pycharm,python,venv,windows,macos,virtualenv

app/api/v1/endpoints/stocks/__init__.py

Whitespace-only changes.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from fastapi import APIRouter, UploadFile, File, Depends
2+
from app.api_payload.code.error_status import ErrorStatus
3+
from app.db.session import get_db
4+
from sqlalchemy.orm import Session
5+
from app.api_payload.code.success_status import SuccessStatus
6+
from app.core.response import success_response
7+
from app.exceptions.base import APIException
8+
from app.schemas.base_response import BaseResponse
9+
from app.schemas.image_search import ImageSearchResponse
10+
from app.services.image_search_service import ImageSearchService
11+
12+
router = APIRouter()
13+
14+
@router.post(
15+
"/search-by-image",
16+
response_model=BaseResponse,
17+
summary="이미지 기반 종목 검색",
18+
description="""
19+
업로드된 이미지로부터 브랜드를 인식하여 종목 정보를 반환합니다.
20+
최대 5MB 파일만 업로드 가능합니다.
21+
""",
22+
tags=["Stock"]
23+
)
24+
async def search_stock_by_image(
25+
image: UploadFile = File(...),
26+
db: Session = Depends(get_db)
27+
):
28+
contents = await image.read()
29+
MAX_IMAGE_SIZE = 5 * 1024 * 1024 # 5MB
30+
31+
# 파일 크기 제한(5MB)
32+
if len(contents) > MAX_IMAGE_SIZE:
33+
raise APIException(ErrorStatus.FILE_TOO_LARGE)
34+
35+
result: ImageSearchResponse = ImageSearchService.search_stock_by_image(contents, db)
36+
37+
return success_response(
38+
data=result,
39+
status=SuccessStatus.IMAGE_SEARCH_SUCCESS
40+
)

app/api/v1/routers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from fastapi import APIRouter
2-
from app.api.v1.routers import stock_router, backtest_router, pattern_detection_router
2+
from app.api.v1.routers import stock_router, backtest_router, pattern_detection_router, emotion_diary_router
33

44
router = APIRouter()
55

66
router.include_router(stock_router.router, prefix="/v1")
77
router.include_router(backtest_router.router, prefix="/v1")
88
router.include_router(pattern_detection_router.router, prefix="/v1")
9+
router.include_router(emotion_diary_router.router, prefix="/v1")
910

1011
# 아래에 추가 ex. router.include_router(user_router.router, prefix="/v1")

app/api/v1/routers/stock_router.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi import APIRouter
2-
from app.api.v1.endpoints import stock
2+
from app.api.v1.endpoints.stocks import stock, image_search
33

44
router = APIRouter()
55
router.include_router(stock.router, prefix="/stocks", tags=["Stock"])
6+
router.include_router(image_search.router, prefix="/stocks", tags=["Stock"])

app/api_payload/code/error_status.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class ErrorStatus(Enum):
77
VALIDATION_ERROR = ("VALIDATION_ERROR", "유효성 검사 실패", status.HTTP_400_BAD_REQUEST)
88
MODEL_ERROR = ("MODEL_ERROR", "모델 예측 중 오류 발생", status.HTTP_500_INTERNAL_SERVER_ERROR)
99
NOTIFICATION_SEND_FAILED = ("NOTIFICATION_SEND_FAILED", "알림 전송에 실패했습니다.", status.HTTP_500_INTERNAL_SERVER_ERROR)
10+
GPT_API_ERROR = ("GPT_API_ERROR", "GPT API 호출에 실패했습니다.", status.HTTP_502_BAD_GATEWAY)
11+
FILE_TOO_LARGE = ("FILE_TOO_LARGE", "업로드한 파일의 크기가 너무 큽니다.", status.HTTP_413_REQUEST_ENTITY_TOO_LARGE)
12+
GPT_RESPONSE_PARSE_ERROR = ("GPT_RESPONSE_PARSE_ERROR", "GPT 응답을 JSON으로 파싱할 수 없습니다.", status.HTTP_500_INTERNAL_SERVER_ERROR)
1013

1114
# 종목
1215
STOCK_NOT_FOUND = ("STOCK_NOT_FOUND", "해당 종목을 찾을 수 없습니다.", status.HTTP_404_NOT_FOUND)

app/api_payload/code/success_status.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class SuccessStatus(Enum):
99
BACKTEST_EXECUTED = ("BACKTEST_EXECUTED", "백테스팅 실행 성공", status.HTTP_200_OK)
1010
STOCK_PREDICTED = ("STOCK200", "15일 예측 성공", status.HTTP_200_OK)
1111
PATTERN_DETECTION_EXECUTED = ("PATTERN_DETECTION_EXECUTED", "패턴 감지 성공", status.HTTP_200_OK)
12+
IMAGE_SEARCH_SUCCESS = ("IMAGE_SEARCH_SUCCESS", "이미지로 종목 검색 성공", status.HTTP_200_OK)
13+
DIARY_ANALYSIS_SUCCESS = ("DIARY_ANALYSIS_SUCCESS", "감정 분석 성공", status.HTTP_200_OK)
1214

1315
def __init__(self, code, message, http_status):
1416
self.code = code

app/core/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from pydantic_settings import BaseSettings
22
from typing import Literal
3-
from pydantic import HttpUrl
3+
from pydantic import HttpUrl, SecretStr
44

55
class Settings(BaseSettings):
66
DATABASE_URL: str
77
ENV: Literal["local", "prod"] = "prod" # 기본값 prod
88
DISCORD_WEBHOOK_URL: str
99
SPRING_SERVER_BASE_URL: HttpUrl
10+
GOOGLE_APPLICATION_CREDENTIALS: str
11+
OPENAI_API_KEY: SecretStr
1012

1113
class Config:
1214
env_file = ".env"

app/crud/stock.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from difflib import get_close_matches
2+
from sqlalchemy import func
13
from sqlalchemy.orm import Session
24
from app.models.stock import Stock
35

@@ -7,3 +9,33 @@ def get_stock_by_id(db: Session, stock_id: int) -> Stock | None:
79
stock_id 기준으로 종목 정보를 조회합니다.
810
"""
911
return db.query(Stock).filter(Stock.id == stock_id).first()
12+
13+
def get_stock_by_name(db: Session, name: str) -> Stock | None:
14+
"""
15+
종목명 기준으로 종목 정보를 조회합니다.
16+
"""
17+
18+
# 입력 정규화
19+
# 공백 제거 및 대소문자 무시
20+
normalized_name = name.replace(" ", "").lower()
21+
22+
# 1차 : 정확히 일치할 경우 반환
23+
stock = (
24+
db.query(Stock)
25+
.filter(func.replace(func.lower(Stock.name), " ", "") == normalized_name)
26+
.first()
27+
)
28+
if stock:
29+
return stock
30+
31+
# 2차 : 유사도 기반 매칭 결과 반환
32+
all_stock_names = db.query(Stock.name).all()
33+
name_list = [str(row[0]) for row in all_stock_names]
34+
close_matches = get_close_matches(name, name_list, n=1, cutoff=0.6) # 유사도 60% 이상
35+
36+
if close_matches:
37+
return db.query(Stock).filter(Stock.name == close_matches[0]).first()
38+
39+
# 없을 경우 None 리턴
40+
return None
41+

0 commit comments

Comments
 (0)