-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag.py
More file actions
151 lines (124 loc) · 4.87 KB
/
rag.py
File metadata and controls
151 lines (124 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import json
import os
from typing import cast
import chromadb
import pandas as pd
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv()
class RAGSystem:
"""
법률 문제를 풀기 위한 RAG 시스템
"""
def __init__(self) -> None:
# 1. 벡터 DB 초기화
self.client = chromadb.PersistentClient(path="./chroma_db")
# 2. 임베딩 모델 설정
self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
# 3. 컬렉션 설정
self.collection = self.client.get_or_create_collection(
name="legal_data",
embedding_function=self.openai_ef, # type: ignore
)
# 4. LLM 클라이언트 초기화
self.llm_client = OpenAI()
def ingest_data(self, csv_path: str) -> None:
"""
CSV 데이터를 읽어서 벡터 DB에 적재하는 함수입니다.
"""
if self.collection.count() > 0:
print(f"데이터 이미 적재되어 있음: {self.collection.count()}")
return
print("데이터 적재 중")
df = pd.read_csv(csv_path)
documents: list[str] = []
ids: list[str] = []
metadatas: list[dict[str, str | int | float | bool]] = []
# 데이터 전처리
for idx, row in df.iterrows():
answer_val: str | int | float | bool | None = cast(
str | int | float | bool | None, row["answer"]
)
answer_value: str | int | float | bool = (
str(answer_val) if answer_val is not None else ""
)
try:
category_val: str | int | float | bool | None = cast(
str | int | float | bool | None, row["Category"]
)
category_value: str | int | float | bool = (
str(category_val) if category_val is not None else ""
)
except KeyError:
category_value = ""
text = (
f"ID: {idx}\n"
f"질문: {row['question']}\n"
f"A: {row['A']}\n"
f"B: {row['B']}\n"
f"C: {row['C']}\n"
f"D: {row['D']}\n"
f"정답: {answer_value}\n"
)
documents.append(text)
ids.append(f"id_{idx}")
metadatas.append({"answer": answer_value, "category": category_value})
# 배치 처리
batch_size = 100
for i in range(0, len(documents), batch_size):
self.collection.add(
documents=documents[i : i + batch_size],
metadatas=metadatas[i : i + batch_size], # type: ignore
ids=ids[i : i + batch_size],
)
print("데이터 적재 완료")
def get_answer(self, query: str) -> str:
"""
사용자 질문에 대한 검색 및 생성을 통해 정답을 반환하는 함수입니다.
"""
# 1. Retrieve
results = self.collection.query(query_texts=[query], n_results=8)
if not results["documents"] or not results["documents"][0]:
context = "관련된 문제를 찾을 수 없습니다."
else:
context = "\n\n".join(results["documents"][0])
# 2. Prompt Engineering
prompt = f"""
# Role
당신은 대한민국 법률 전문가입니다.
# Context (유사 기출문제 족보)
아래는 과거에 출제된 문제와 정답입니다.
{context}
# Problem
{query}
# Instructions
1. [Context]를 분석하여, 이번 문제와 **유사한 판례나 법리**가 있는지 확인하세요.
2. **[Context]에 동일한 문제나 유사한 사례가 있다면, 그 정답(족보)을 최우선으로 따르세요.** (매우 중요)
3. 족보에 없다면, 당신의 법률 지식을 사용하여 정답을 추론하세요.
4. 정답은 반드시 A, B, C, D 중 하나만 선택하세요.
# Output Format (JSON Only)
{{
"reasoning": "정답 도출 근거 요약",
"answer": "A"
}}
"""
try:
# 3. Generate
response = self.llm_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": prompt}],
temperature=0,
seed=42,
response_format={"type": "json_object"},
)
content = response.choices[0].message.content
if content is None:
return ""
res_json = json.loads(content)
return res_json.get("answer", "").strip().upper()
except Exception as e:
print(f"Error: {e}")
return ""