Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
# mongodb_uri: str = "mongodb://localhost:27017"
# db_name: str = "emotion_db"
collection_name: str = "predictions"
model_dir: str = os.getenv("MODEL_DIR", "models/0717_kobert_5_emotion_model")
title: str = "MovieMood - KoBERT Emotion API"
Expand Down
6 changes: 0 additions & 6 deletions const.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
# 7 emotion states (translated from the original Korean note):
# joy, sadness, anger, surprise, disgust, fear, neutral
# NOTE(review): the note above lists 7 emotions but LABELS_6 holds only 6
# ("disgust" is missing) — confirm this matches the model's output head.
# Labels are kept in alphabetical order; index position must match the
# classifier's logit order, presumably — verify against training config.
LABELS_6 = [
"anger", "fear", "joy", "neutral", "sadness", "surprise"
]
# Superseded draft kept for reference; the active LABELS_5 is defined below.
# LABELS_5 = [
# "anger", "fear", "joy", "neutral", "sadness"
# ]

LABELS_5 = [
"anger", "disgust", "fear", "joy", "sadness"
Expand Down
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uvicorn
from fastapi import FastAPI

from config import settings
from routers.home import router as home_router
from routers.predict import router as predict_router
Expand Down
1 change: 1 addition & 0 deletions routers/home.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter

from config import settings

router = APIRouter()
Expand Down
34 changes: 3 additions & 31 deletions routers/predict.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,14 @@
from fastapi import APIRouter, HTTPException
from datetime import datetime

from fastapi import APIRouter, HTTPException

from schemas import TextItem, Prediction
from services.prediction import (
predict_emotion_full,
predict_emotion_split_avg,
predict_emotion_overall_avg,
)

router = APIRouter(prefix="/predict", tags=["Prediction"])

@router.post("/full", response_model=Prediction)
async def predict_full(item: TextItem):
    """Predict emotion probabilities for the full input text in one pass.

    Args:
        item: Request body carrying the raw text to classify.

    Returns:
        A dict matching the ``Prediction`` schema: the original text, the
        per-label probabilities from ``predict_emotion_full``, and a
        timezone-aware UTC timestamp.

    Raises:
        HTTPException: 500 with the underlying error message if prediction
            fails for any reason.
    """
    # Function-local import keeps this fix self-contained in the block.
    from datetime import timezone

    try:
        probs = predict_emotion_full(item.text)
        record = {
            "text": item.text,
            "probabilities": probs,
            # datetime.utcnow() is deprecated (Python 3.12+) and returns a
            # naive datetime; use an aware UTC timestamp instead.
            "timestamp": datetime.now(timezone.utc)
        }
        # await collection.insert_one(record)  # persistence disabled for now
        return record
    except Exception as e:
        # Surface any prediction failure as a 500 at the API boundary.
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/split_avg", response_model=Prediction)
async def predict_split_avg(item: TextItem):
    """Predict emotion probabilities by splitting the text and averaging.

    Delegates to ``predict_emotion_split_avg`` (splitting strategy lives in
    the service layer).

    Args:
        item: Request body carrying the raw text to classify.

    Returns:
        A dict matching the ``Prediction`` schema: the original text, the
        averaged per-label probabilities, and a timezone-aware UTC timestamp.

    Raises:
        HTTPException: 500 with the underlying error message if prediction
            fails for any reason.
    """
    # Function-local import keeps this fix self-contained in the block.
    from datetime import timezone

    try:
        probs = predict_emotion_split_avg(item.text)
        record = {
            "text": item.text,
            "probabilities": probs,
            # datetime.utcnow() is deprecated (Python 3.12+) and returns a
            # naive datetime; use an aware UTC timestamp instead.
            "timestamp": datetime.now(timezone.utc)
        }
        # await collection.insert_one(record)  # persistence disabled for now
        return record
    except Exception as e:
        # Surface any prediction failure as a 500 at the API boundary.
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/overall_avg", response_model=Prediction)
async def predict_overall_avg(item: TextItem):
Expand All @@ -46,7 +19,6 @@ async def predict_overall_avg(item: TextItem):
"probabilities": probs,
"timestamp": datetime.utcnow()
}
# await collection.insert_one(record)
return record
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
5 changes: 2 additions & 3 deletions routers/review.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# TODO: 데이터 저장하기(mongo or mysql)
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

import requests
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(
prefix = "/review",
Expand Down
1 change: 1 addition & 0 deletions schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pydantic import BaseModel


class TextItem(BaseModel):
text: str

Expand Down
5 changes: 3 additions & 2 deletions services/prediction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
from transformers import AutoModelForSequenceClassification
from kobert_tokenizer import KoBERTTokenizer
from const import LABELS_5
from transformers import AutoModelForSequenceClassification

from config import settings
from const import LABELS_5

_tokenizer = KoBERTTokenizer.from_pretrained(settings.model_dir)
_model = AutoModelForSequenceClassification.from_pretrained(settings.model_dir)
Expand Down