Skip to content

Commit f53aeea

Browse files
authored
Merge pull request #10 from 2025-Graduation-Design/new
new -> main
2 parents b685952 + 3c7a31e commit f53aeea

11 files changed

Lines changed: 138 additions & 65 deletions

File tree

app/diary/models.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import Column, Integer, ForeignKey, Text, Float, DateTime, JSON
1+
from sqlalchemy import Column, Integer, ForeignKey, Text, Float, DateTime, JSON, String
22
from sqlalchemy.orm import relationship
33
from datetime import datetime
44
from app.database import Base
@@ -16,14 +16,31 @@ class Diary(Base):
1616
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
1717
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
1818

19-
# ✅ 클래스명을 정확히 사용 (User, EmotionType)
2019
user = relationship("User", back_populates="diaries")
2120
emotion = relationship(EmotionType, back_populates="diaries")
21+
recommended_songs = relationship("RecommendedSong", back_populates="diary", cascade="all, delete-orphan")
2222

2323

2424
class diaryEmbedding(Base):
2525
__tablename__ = "diaryEmbedding"
2626

2727
id = Column(Integer, primary_key=True, autoincrement=True)
2828
diary_id = Column(Integer, nullable=False)
29-
embedding = Column(JSON, nullable=False)
29+
embedding = Column(JSON, nullable=False)
30+
31+
32+
class RecommendedSong(Base):
33+
__tablename__ = "recommendedSongs"
34+
35+
id = Column(Integer, primary_key=True)
36+
diary_id = Column(Integer, ForeignKey('diary.id'), nullable=False)
37+
song_id = Column(Integer, nullable=False)
38+
song_name = Column(String(256))
39+
artist = Column(JSON)
40+
genre = Column(String(64))
41+
album_image = Column(String(512))
42+
best_lyric = Column(Text)
43+
similarity_score = Column(Float)
44+
created_at = Column(DateTime, default=datetime.utcnow)
45+
46+
diary = relationship("Diary", back_populates="recommended_songs")

app/diary/router.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from app.emotion.router import predict_emotion
77
from app.statistics.models import EmotionStatistics
88
from app.user.auth import get_current_user
9-
from app.diary.models import Diary
10-
from app.diary.schemas import DiaryCreateRequest, DiaryUpdateRequest, DiaryResponse, DiaryCountResponse
9+
from app.diary.models import Diary, RecommendedSong
10+
from app.diary.schemas import DiaryCreateRequest, DiaryUpdateRequest, DiaryResponse, DiaryCountResponse, SongResponse
1111
from app.user.models import User
1212
from app.embedding.models import kobert, save_diary_embedding, split_sentences, get_user_preferred_genres, \
1313
get_songs_by_genre, get_song_embeddings, calculate_similarity
@@ -18,15 +18,13 @@
1818
import torch
1919
import numpy as np
2020
import heapq
21-
import torch.nn.functional as F
2221
from datetime import datetime
2322

2423
router = APIRouter()
2524

2625
logging.basicConfig(level=logging.INFO)
2726
logger = logging.getLogger(__name__)
2827

29-
# 📝 일기 작성 API
3028
@router.post("", response_model=DiaryResponse, status_code=201, summary="일기 작성 & 노래 추천",
3129
description="일기를 작성하면 자동으로 임베딩을 진행하고, 사용자의 선호 장르 내에서 가장 유사한 노래를 추천합니다.")
3230
async def create_diary(
@@ -305,6 +303,32 @@ async def create_diary_with_music_recommend_top3(
305303

306304
save_diary_embedding(session, new_diary.id, combined_embedding)
307305

306+
recommended_songs = [
307+
{
308+
"song_id": match["song_id"],
309+
"song_name": match["metadata"]["song_name"],
310+
"best_lyric": " ".join(match["lyric_chunk"]),
311+
"similarity_score": round(float(sim), 4),
312+
"album_image": match["metadata"]["album_image"],
313+
"artist": match["metadata"]["artist"],
314+
"genre": match["metadata"]["genre"]
315+
}
316+
for sim, match in top_3
317+
]
318+
319+
for song_data in recommended_songs:
320+
new_song = RecommendedSong(
321+
diary_id=new_diary.id,
322+
song_id=song_data["song_id"], # MongoDB ID 문자열 변환
323+
song_name=song_data["song_name"],
324+
artist=song_data["artist"],
325+
genre=song_data["genre"],
326+
album_image=song_data["album_image"],
327+
best_lyric=song_data["best_lyric"],
328+
similarity_score=song_data["similarity_score"]
329+
)
330+
session.add(new_song)
331+
308332
# 9-1) 감정 통계 업데이트 또는 추가
309333
existing_stat = session.query(EmotionStatistics).filter(
310334
EmotionStatistics.user_id == current_user.id,
@@ -327,20 +351,9 @@ async def create_diary_with_music_recommend_top3(
327351
)
328352
session.add(new_stat)
329353

330-
# 10) 응답 구성
331-
recommended_songs = [
332-
{
333-
"song_id": match["song_id"],
334-
"song_name": match["metadata"]["song_name"],
335-
"best_lyric": " ".join(match["lyric_chunk"]),
336-
"similarity_score": round(float(sim), 4),
337-
"album_image": match["metadata"]["album_image"],
338-
"artist": match["metadata"]["artist"],
339-
"genre": match["metadata"]["genre"]
340-
}
341-
for sim, match in top_3
342-
]
354+
session.commit()
343355

356+
# 10) 응답 구성
344357
response_data = {
345358
"id": new_diary.id,
346359
"user_id": new_diary.user_id,
@@ -349,7 +362,9 @@ async def create_diary_with_music_recommend_top3(
349362
"confidence": confidence_full,
350363
"created_at": new_diary.created_at,
351364
"updated_at": new_diary.updated_at,
352-
"recommended_songs": recommended_songs
365+
"recommended_songs": recommended_songs,
366+
"top_emotions": [{"emotion_id": emo_id, "score": round(score, 4)}
367+
for emo_id, score in sorted(emotion_vote_counter.items(), key=lambda x: -x[1])[:3]]
353368
}
354369

355370
logger.info("추천 결과: %s", json.dumps(response_data, indent=2, ensure_ascii=False, default=str))
@@ -374,6 +389,33 @@ def get_diary(
374389

375390
return diary
376391

392+
@router.get("/{diary_id}/recommended-songs", response_model=list[SongResponse],
393+
summary="추천 노래 조회",
394+
description="특정 일기에 대한 추천 노래 리스트를 조회합니다.")
395+
def get_recommended_songs_by_diary(
396+
diary_id: int,
397+
current_user: User = Depends(get_current_user),
398+
db: Session = Depends(get_db)
399+
):
400+
# 1. 해당 일기가 유저의 것인지 검증
401+
diary = db.query(Diary).filter(
402+
Diary.id == diary_id,
403+
Diary.user_id == current_user.id
404+
).first()
405+
406+
if not diary:
407+
raise HTTPException(status_code=404, detail="일기를 찾을 수 없습니다.")
408+
409+
# 2. 추천곡 조회
410+
songs = db.query(RecommendedSong).filter(
411+
RecommendedSong.diary_id == diary_id
412+
).order_by(RecommendedSong.similarity_score.desc()).all()
413+
414+
if not songs:
415+
raise HTTPException(status_code=404, detail="추천된 노래가 없습니다.")
416+
417+
return songs
418+
377419
@router.get("", response_model=List[DiaryResponse],
378420
summary="내 일기 목록 조회",
379421
description="로그인한 사용자가 작성한 모든 일기를 조회합니다.")

app/diary/schemas.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from datetime import datetime
33
from typing import Optional
44

5+
class EmotionScore(BaseModel):
6+
emotion_id: int
7+
score: float
8+
59
# Request
610
class DiaryCreateRequest(BaseModel):
711
content: str
@@ -10,12 +14,26 @@ class DiaryUpdateRequest(BaseModel):
1014
content: Optional[str] = None
1115

1216
# Response
17+
class SongResponse(BaseModel):
18+
song_id: int # MongoDB _id → 문자열 변환
19+
song_name: str
20+
artist: list[str]
21+
genre: str
22+
album_image: str
23+
best_lyric: str
24+
similarity_score: float
25+
26+
class Config:
27+
allow_population_by_field_name = True
28+
1329
class DiaryResponse(BaseModel):
1430
id: int
1531
user_id: int
1632
content: str
1733
emotiontype_id: Optional[int] = None
1834
confidence: Optional[float] = None
35+
recommended_songs: list[SongResponse]
36+
top_emotions: list[EmotionScore] = None
1937
created_at: datetime
2038
updated_at: datetime
2139

app/embedding/router.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,22 @@ async def embed_songs(
4242
try:
4343
embeddings = [
4444
kobert.get_embedding(line)
45-
for line in lyrics if line.strip()
45+
for line in lyrics
46+
if line.strip() and len(line.strip()) >= 15
4647
]
48+
49+
if not embeddings:
50+
logger.info(f"스킵됨 (song_id={song_id}): 임베딩 가능한 가사가 없음")
51+
continue
52+
4753
save_song_embedding(db, song_id, embeddings)
4854
processed_songs.append({"song_id": song_id, "status": "embedded"})
4955
except Exception as e:
5056
logger.error(f"임베딩 실패 (song_id={song_id}): {e}")
5157
continue
5258

5359
db.commit()
54-
time.sleep(3) # 한 배치 끝나고 쿨다운
60+
time.sleep(2) # 한 배치 끝나고 쿨다운
5561

5662
return {
5763
"total_songs": len(songs),

app/emotion/models.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
2-
import torch.nn as nn
3-
from transformers import BertModel, BertTokenizer, AutoTokenizer, AutoModel
2+
from transformers import BertForSequenceClassification
3+
from kobert_tokenizer import KoBERTTokenizer
44

55
from sqlalchemy import Column, Integer, ForeignKey, String
66
from sqlalchemy.orm import relationship
@@ -30,24 +30,9 @@ class EmotionType(Base):
3030
quadrant = Column(Integer, nullable=False)
3131
related_emotion_id = Column(Integer, ForeignKey("emotionType.id", ondelete="SET NULL"), nullable=True)
3232

33-
# ✅ 관계 명칭을 `diaries`로 변경
3433
diaries = relationship("Diary", back_populates="emotion")
3534

36-
class EmotionClassifier(nn.Module):
37-
def __init__(self, num_classes=8):
38-
super(EmotionClassifier, self).__init__()
39-
self.bert = AutoModel.from_pretrained("skt/kobert-base-v1")
40-
self.dropout = nn.Dropout(0.1)
41-
self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
42-
43-
def forward(self, input_ids, attention_mask):
44-
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
45-
pooled = outputs.pooler_output
46-
return self.classifier(self.dropout(pooled))
47-
48-
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
49-
50-
model = EmotionClassifier(num_classes=len(emotion_labels))
51-
state_dict = torch.load("app/emotion/best_model(1st).pt", map_location=torch.device("cpu"))
52-
model.load_state_dict(state_dict)
35+
tokenizer = KoBERTTokenizer.from_pretrained("skt/kobert-base-v1")
36+
model = BertForSequenceClassification.from_pretrained("skt/kobert-base-v1", num_labels=8)
37+
model.load_state_dict(torch.load("app/emotion/best_model(7th).pt", map_location="cpu"))
5338
model.eval()

app/emotion/router.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@
44

55
logger = logging.getLogger(__name__)
66

7-
87
def predict_emotion(text: str):
9-
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
8+
inputs = tokenizer(
9+
text,
10+
return_tensors='pt',
11+
truncation=True,
12+
padding=True,
13+
max_length=256
14+
)
1015
with torch.no_grad():
1116
outputs = model(
1217
input_ids=inputs["input_ids"],
1318
attention_mask=inputs["attention_mask"]
1419
)
15-
# outputs가 Tensor일 경우 그냥 사용
16-
if isinstance(outputs, tuple) or isinstance(outputs, list):
17-
logits = outputs[0]
18-
else:
19-
logits = outputs
2020

21+
logits = outputs.logits
2122
probs = torch.softmax(logits, dim=1)
2223
pred_index = torch.argmax(probs, dim=1).item()
2324

app/emotion/schemas.py

Whitespace-only changes.

app/genre/router.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ def get_all_genres(db: Session = Depends(get_db)):
1313

1414
@router.post("/new", summary="장르 추가", description="현재 DB에 있는 음악 장르를 추가합니다.")
1515
async def add_genres_from_mongodb(mongodb=Depends(get_mongodb), db: Session = Depends(get_db)):
16-
"""
17-
MongoDB에서 'genre' 컬럼을 가져와 MySQL의 Genre 테이블에 추가
18-
(중복 제거 후 추가)
19-
"""
2016
mongo_genres = await mongodb["song_meta"].distinct("genre")
2117
genre_set = set()
2218

app/main.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from fastapi import FastAPI
2+
from fastapi.middleware.cors import CORSMiddleware
23

34
from app.user.router import router as user_router
45
from app.diary.router import router as diary_router
@@ -18,6 +19,14 @@
1819
app.include_router(embedding_router, prefix="/embedding", tags=["embedding"])
1920
app.include_router(statistics_router, prefix="/statistics", tags=["statistics"])
2021

22+
app.add_middleware(
23+
CORSMiddleware,
24+
allow_origins=["*"], # 모든 도메인 허용 (개발용)
25+
allow_credentials=True,
26+
allow_methods=["*"],
27+
allow_headers=["*"],
28+
)
29+
2130
@app.get("/")
2231
def read_root():
23-
return {"message": "Hello, Melog API!"}
32+
return {"message": "멜로그 시작 화면에 뜰 메세지에요~"}

app/transaction.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ def transactional_session(db: Session):
77
Spring Boot @Transactional처럼 사용할 수 있는 컨텍스트 매니저
88
"""
99
try:
10-
yield db # ✅ 트랜잭션 시작
11-
db.commit() # ✅ 정상 실행 시 커밋
10+
yield db
11+
db.commit()
1212
except Exception as e:
13-
db.rollback() # 🚨 예외 발생 시 롤백
13+
db.rollback()
1414
raise e
1515
finally:
16-
db.close() # ✅ 세션 닫기
16+
db.close()

0 commit comments

Comments
 (0)