-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
291 lines (251 loc) · 9.95 KB
/
main.py
File metadata and controls
291 lines (251 loc) · 9.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
from fastapi import FastAPI, HTTPException, Path
from finance.stock_api import get_stock_info, format_currency, format_volume
from app.DB import DB
from app.neo4jrag import Neo4jRAG
from dotenv import load_dotenv
from datetime import datetime, timedelta
from pykrx import stock
from pydantic import BaseModel
app = FastAPI()
db = DB() # DB 클래스 인스턴스 생성
# .env 파일 로드
load_dotenv()
# 기본 루트 경로
@app.get("/")
async def root():
return {"message": "AI API 서버가 실행 중입니다."}
# 기업정보 가져오기
@app.get("/companies/stock-info/{company_id}")
async def get_company_stock_info(company_id: str):
# 1. MySQL에서 기업 정보 가져오기 (가정)
company_info = db.get_company_info(company_id)
if not company_info:
raise HTTPException(status_code=404, detail="회사 정보를 찾을 수 없습니다.")
try:
if "ticker" not in company_info or not company_info["ticker"]:
raise HTTPException(
status_code=404, detail="해당 기업의 종목코드(ticker) 정보가 없습니다."
)
# 2. 주식 정보 조회
stock_info = get_stock_info(company_info["ticker"])
# 3. 데이터 변환 적용
formatted_response = {
"company_name": company_info.get("company", "정보 없음"),
"ticker": company_info.get("ticker", "정보 없음"),
"trading_volume": format_volume(stock_info.get("trading_volume", 0)),
"trading_value": format_currency(stock_info.get("trading_value", 0)),
"low_52weeks": format_currency(stock_info.get("low_52weeks", 0)),
"high_52weeks": format_currency(stock_info.get("high_52weeks", 0)),
"change_amount": stock_info.get("change_amount", 0),
"change_percent": stock_info.get("change_percent", 0.0),
}
print("Formatted response:", formatted_response) # FastAPI가 반환하기 전에 확인
print(type(formatted_response))
return formatted_response
except Exception as e:
raise HTTPException(status_code=500, detail=f"서버 오류: {str(e)}")
@app.get("/companies/stock-info/{ticker}/chart/{days}")
async def get_stock_chart(ticker: str, days: int):
try:
# 오늘 날짜 기준으로 시작 날짜 계산
today = datetime.today()
start_date = (today - timedelta(days=days)).strftime("%Y%m%d")
end_date = today.strftime("%Y%m%d")
# Pykrx를 이용하여 OHLCV 데이터 가져오기
df = stock.get_market_ohlcv_by_date(start_date, end_date, ticker)
# DataFrame이 비어있는 경우 예외 처리
if df.empty:
raise HTTPException(
status_code=404, detail="해당 기간 동안 주식 데이터가 없습니다."
)
# JSON 응답 포맷 변경
stock_data = [
{
"x": date.strftime("%Y-%m-%d"),
"o": row["시가"],
"h": row["고가"],
"l": row["저가"],
"c": row["종가"],
}
for date, row in df.iterrows()
]
return {"stockData": stock_data}
except Exception as e:
raise HTTPException(status_code=500, detail=f"서버 오류: {str(e)}")
# 연관 기업
@app.get("/companies/{company_id}/related/")
def get_related_companies(company_id: str):
rag = Neo4jRAG()
db = DB()
company = db.query(
f'SELECT company FROM company WHERE company_id = "{company_id}"'
)[0]["company"]
related_companies_name = rag.get_related_companies(company)["related_companies"]
# [{"company":"SK하이닉스"},{"company":"LG전자"},{"company":"삼성SDI"}]
related_companies = [
{
"company_id": db.query(
f"SELECT company_id FROM company WHERE company = '{company['company']}'"
)[0]["company_id"],
"company": company["company"],
}
for company in related_companies_name
]
db.close()
return {
"company_id": company_id,
"company": company,
"related_companies": related_companies,
}
# 연관 키워드
@app.get("/keywords/{company_id}/related/")
def get_related_companies(company_id: str):
rag = Neo4jRAG()
db = DB()
company = db.query(
f'SELECT company FROM company WHERE company_id = "{company_id}"'
)[0]["company"]
response = rag.get_response_from_prompt(
"""최근 {company} 기업의 주가에 가장 영향이 있었던 단어 10개를 아래와 같은 형식으로 반환해줘. 그 외의 텍스트 출력은 일절 없어야해. 단어는 띄어쓰기 없는 단어를 말해
#형식
{{
"keywords": [
{{
"keyword_id": "id1",
"keyword": "keyword1"
}},
{{
"keyword_id": "id2",
"keyword": "keyword2"
}},
{{
"keyword_id": "id3",
"keyword": "keyword3"
}},
...
],
}}""",
company=company,
)
db.close()
return response
# 관련 뉴스
@app.get("/news/{company_id}")
def get_news(company_id: str):
print('here!')
# rag = Neo4jRAG()
db = DB()
company = db.query(
f"SELECT company FROM company WHERE company_id = '{company_id}'"
)[0]["company"]
sql = f"""SELECT
news_id,
(LENGTH(title) - LENGTH(REPLACE(title, '{company}', ''))) / LENGTH('{company}') +
(LENGTH(article_text) - LENGTH(REPLACE(article_text, '{company}', ''))) / LENGTH('{company}') AS keyword_count
FROM
news
ORDER BY
keyword_count DESC
LIMIT 3;"""
news = db.query(sql)
news_id = [n["news_id"] for n in news]
db.close()
return {"news_id": news_id}
class NewsID(BaseModel):
company_id: str
news_id: list
# 관련 뉴스 - 감성
@app.get('/news/sentiment/{company_id}/{news_ids}')
async def get_sentiment(company_id:str, news_ids:str):
news_id = news_ids.split(',')
rag = Neo4jRAG()
db = DB()
company = db.query(
f"SELECT company FROM company WHERE company_id = '{company_id}'"
)[0]["company"]
sql = f"SELECT news_id, title, sub_title, url, date, article_text FROM news WHERE news_id IN ({', '.join(news_id)})"
news = db.query(sql)
news_text_list = [
f'news_id : {n["news_id"]}\ntitle : {n["title"]}\nsub_title : {n["sub_title"]}\ncontent : {n["article_text"]}' for n in news
]
news_dict = dict([(n['news_id'], n) for n in news])
news_text = "\n\n".join(news_text_list)
prompt_text = """
# 지시문
당신은 10년차 증권사 애널리스트입니다. 특히 {company} 기업의 주가에 관심이 많아 관련 뉴스를 매일 확인하고 분석하여 주가에 미치는 영향을 판단합니다.
아래의 제약조건과 입력문을 토대로 최고의 결과를 출력해주세요. 100점 만점의 결과를 출력해주세요.
# # 제약조건
# 1. 뉴스는 3가지 뉴스가 주어집니다.
# 2. 각 뉴스별로, 해당 뉴스가 {company} 기업의 주가에 미치는 영향을 분석(긍정 또는 부정)을 하여 'sentiment'에 결과를 출력합니다.
# 3. 해당 뉴스를 보고, 기업의 주가에 영향을 많이 미치는 상위 3개의 단어를 키워드로 선정해주세요.
# 4. 아래의 형식에 따라 JSON 형식으로 결과를 출력해 주세요. 그 외 어떠한 텍스트도 출력하지 마세요
# ## 형식
# {{
# "news": [{{
# "news_id": "id1",
# "sentiment": "긍정/부정",
# "keywords: ["키워드1", "키워드2", "키워드3"]
# }},
# {{
# "news_id": "id2",
# "sentiment": "긍정/부정",
# "keywords: ["키워드1", "키워드2", "키워드3"]
# }},
# {{
# "news_id": "id3",
# "sentiment": "긍정/부정",
# "keywords: ["키워드1", "키워드2", "키워드3"]
# }},]
# }}
# ## 뉴스
# {news_text}
# """
db.close()
response = rag.get_response_from_prompt(
prompt_text, company=company, news_text=news_text
)
for r in response["news"]:
news_dict[r['news_id']]['sentiment'] = r['sentiment']
news_dict[r['news_id']]['keywords'] = r['keywords']
news_list = [{"news_id": news_id, "title": d['title'], "sub_title": d['sub_title'], "url": d['url'], "article_text": d["keywords"], "sentiment": d["sentiment"], "date": d["date"]} for news_id, d in news_dict.items()]
return {"news": news_list}
@app.get('/news/summary/{company_id}/{news_ids}')
async def get_summary(company_id:str, news_ids:str):
news_id = news_ids.split(',')
rag = Neo4jRAG()
db = DB()
company = db.query(
f"SELECT company FROM company WHERE company_id = '{company_id}'"
)[0]["company"]
sql = f"SELECT news_id, title, sub_title, url, date, article_text FROM news WHERE news_id IN ({', '.join(news_id)})"
news = db.query(sql)
news_text_list = [
f'news_id : {n["news_id"]}\ntitle : {n["title"]}\nsub_title : {n["sub_title"]}\ncontent : {n["article_text"]}'
for n in news
]
news_text = "\n\n".join(news_text_list)
prompt_text = """
# 지시문
당신은 10년차 증권사 애널리스트입니다. 특히 {company} 기업의 주가에 관심이 많아 관련 뉴스를 매일 확인하고 분석하여 요약합니다.
아래의 제약조건과 입력문을 토대로 최고의 결과를 출력해주세요. 100점 만점의 결과를 출력해주세요.
# 제약조건
1. 뉴스는 3가지 뉴스가 주어집니다.
2. 각 뉴스를 한꺼번에 요약해서 50자 이상 70자 이내로 출력해주세요.
3. 해당 요약본의 제목을 12자 이내로 출력해주세요.
4. 아래의 형식에 따라 JSON 형식으로 결과를 출력해 주세요. 그 외 어떠한 텍스트도 출력하지 마세요
## 형식
{{
"title": "제목",
"content": "요약본"
}}
## 뉴스
{news_text}
"""
response = rag.get_response_from_prompt(
prompt_text, company=company, news_text=news_text
)
db.close()
return response
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)