From 9431ffc04a95da3157a561fcb97a232cf33a3cd1 Mon Sep 17 00:00:00 2001 From: Gumraze-git Date: Sun, 20 Jul 2025 00:05:07 +0900 Subject: [PATCH] =?UTF-8?q?[CHORE]=20=EA=B0=90=EC=A0=95=20=EC=98=88?= =?UTF-8?q?=EC=B8=A1=20=EC=97=94=EB=93=9C=ED=8F=AC=EC=9D=B8=ED=8A=B8=20?= =?UTF-8?q?=EC=A0=95=EB=A6=AC=20=EB=B0=8F=20LABELS=20=EC=83=81=EC=88=98=20?= =?UTF-8?q?=EC=A0=95=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `/full` 및 `/split_avg` 엔드포인트 제거 - 사용하지 않는 LABELS_6 삭제 - 전체 코드 레이아웃 및 import 순서 정리 --- config.py | 4 ++-- const.py | 6 ------ main.py | 1 + routers/home.py | 1 + routers/predict.py | 34 +++------------------------------- routers/review.py | 5 ++--- schemas.py | 1 + services/prediction.py | 5 +++-- 8 files changed, 13 insertions(+), 44 deletions(-) diff --git a/config.py b/config.py index b08fd4f..33251db 100644 --- a/config.py +++ b/config.py @@ -1,9 +1,9 @@ import os + from pydantic_settings import BaseSettings, SettingsConfigDict + class Settings(BaseSettings): - # mongodb_uri: str = "mongodb://localhost:27017" - # db_name: str = "emotion_db" collection_name: str = "predictions" model_dir: str = os.getenv("MODEL_DIR", "models/0717_kobert_5_emotion_model") title: str = "MovieMood - KoBERT Emotion API" diff --git a/const.py b/const.py index 42d308f..4b8ba1a 100644 --- a/const.py +++ b/const.py @@ -1,11 +1,5 @@ # 7 가지 감정 상태 # 기쁨, 슬픔, 분노, 놀람, 혐오, 공포, 중립 -LABELS_6 = [ - "anger", "fear", "joy", "neutral", "sadness", "surprise" -] -# LABELS_5 = [ -# "anger", "fear", "joy", "neutral", "sadness" -# ] LABELS_5 = [ "anger", "disgust", "fear", "joy", "sadness" diff --git a/main.py b/main.py index 8ecb9eb..1071103 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ import uvicorn from fastapi import FastAPI + from config import settings from routers.home import router as home_router from routers.predict import router as predict_router diff --git a/routers/home.py b/routers/home.py index 23f95c5..2727bae 100644 --- a/routers/home.py +++ b/routers/home.py @@ -1,4 +1,5 @@ from fastapi import APIRouter + from config import settings router = APIRouter() diff --git a/routers/predict.py b/routers/predict.py index 480a253..aa7f2dd 100644 --- a/routers/predict.py +++ b/routers/predict.py @@ -1,41 +1,14 @@ -from fastapi import APIRouter, HTTPException from datetime import datetime + +from fastapi import APIRouter, HTTPException + from schemas import TextItem, Prediction from services.prediction import ( - predict_emotion_full, - predict_emotion_split_avg, predict_emotion_overall_avg, ) router = APIRouter(prefix="/predict", tags=["Prediction"]) -@router.post("/full", response_model=Prediction) -async def predict_full(item: TextItem): - try: - probs = predict_emotion_full(item.text) - record = { - "text": item.text, - "probabilities": probs, - "timestamp": datetime.utcnow() - } - # await collection.insert_one(record) - return record - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@router.post("/split_avg", response_model=Prediction) -async def predict_split_avg(item: TextItem): - try: - probs = predict_emotion_split_avg(item.text) - record = { - "text": item.text, - "probabilities": probs, - "timestamp": datetime.utcnow() - } - # await collection.insert_one(record) - return record - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) @router.post("/overall_avg", response_model=Prediction) async def predict_overall_avg(item: TextItem): @@ -46,7 +19,6 @@ async def predict_overall_avg(item: TextItem): "probabilities": probs, "timestamp": datetime.utcnow() } - # await collection.insert_one(record) return record except Exception as e: raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/routers/review.py b/routers/review.py index 9898d64..a8d6069 100644 --- a/routers/review.py +++ b/routers/review.py @@ -1,8 +1,7 @@ # TODO: 데이터 저장하기(mongo or mysql) -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - import requests +from fastapi import APIRouter +from pydantic import BaseModel router = APIRouter( prefix = "/review", diff --git a/schemas.py b/schemas.py index 0b84f67..e94eb7b 100644 --- a/schemas.py +++ b/schemas.py @@ -3,6 +3,7 @@ from pydantic import BaseModel + class TextItem(BaseModel): text: str diff --git a/services/prediction.py b/services/prediction.py index df1baa3..3c64668 100644 --- a/services/prediction.py +++ b/services/prediction.py @@ -1,8 +1,9 @@ import torch -from transformers import AutoModelForSequenceClassification from kobert_tokenizer import KoBERTTokenizer -from const import LABELS_5 +from transformers import AutoModelForSequenceClassification + from config import settings +from const import LABELS_5 _tokenizer = KoBERTTokenizer.from_pretrained(settings.model_dir) _model = AutoModelForSequenceClassification.from_pretrained(settings.model_dir)