-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathragflow_client.py
More file actions
254 lines (218 loc) · 7.69 KB
/
ragflow_client.py
File metadata and controls
254 lines (218 loc) · 7.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
from ragflow_sdk import RAGFlow
from openai import OpenAI
import os
class RAGFlowClient:
    """Thin convenience wrapper around the RAGFlow SDK.

    Bundles dataset management, document upload/parsing, chat assistants,
    retrieval, and an OpenAI-compatible chat client behind one object.
    Errors are reported by printing and returning a sentinel (None/[]/False)
    rather than raising, matching the original best-effort style.
    """

    def __init__(self, api_key="ragflow-Q4ZWNiMjE4MWE4NTExZjBiODlmMzZiNj", base_url="http://localhost:8080"):
        """Initialize the RAGFlow client.

        Args:
            api_key: RAGFlow API key.
            base_url: RAGFlow server URL.
        """
        # Keep the credentials on the instance so helpers such as
        # create_openai_client() reuse the exact values given here instead
        # of hard-coding them (bug fix: the key was previously duplicated
        # as a literal, and `self.rag.base_url` is not a public SDK
        # attribute, so the old code always failed with AttributeError).
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.rag = RAGFlow(api_key=api_key, base_url=base_url)

    def create_dataset(self, name, description="", embedding_model="BAAI/bge-large-zh-v1.5",
                       permission="me", chunk_method="naive"):
        """Create a dataset.

        Args:
            name: Dataset name.
            description: Dataset description.
            embedding_model: Embedding model identifier.
            permission: Access permission ("me" or "team").
            chunk_method: Chunking strategy.

        Returns:
            The created dataset object, or None on failure.
        """
        try:
            dataset = self.rag.create_dataset(
                name=name,
                description=description,
                embedding_model=embedding_model,
                permission=permission,
                chunk_method=chunk_method
            )
            print(f"成功创建数据集: {name}")
            return dataset
        except Exception as e:
            print(f"创建数据集失败: {e}")
            return None

    def list_datasets(self, page=1, page_size=30, orderby="create_time", desc=True, id=None, name=None):
        """List datasets.

        Args:
            page: Page number (1-based).
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending when True.
            id: Optional dataset ID filter.
            name: Optional dataset name filter.

        Returns:
            A list of dataset objects ([] on failure).
        """
        # NOTE: `id` shadows the builtin, but it is part of the public
        # signature (mirrors the SDK), so it must stay.
        try:
            datasets = self.rag.list_datasets(
                page=page,
                page_size=page_size,
                orderby=orderby,
                desc=desc,
                id=id,
                name=name
            )
            return datasets
        except Exception as e:
            print(f"获取数据集列表失败: {e}")
            return []

    def delete_datasets(self, ids=None):
        """Delete datasets.

        Args:
            ids: List of dataset IDs to delete.  CAUTION: the SDK treats
                None as "delete everything" — pass an explicit list unless
                that is really what you want.
        """
        try:
            self.rag.delete_datasets(ids=ids)
            print("成功删除数据集")
        except Exception as e:
            print(f"删除数据集失败: {e}")

    def upload_document(self, dataset, file_path):
        """Upload a document into a dataset.

        Args:
            dataset: Dataset object returned by the SDK.
            file_path: Path of the file to upload.

        Returns:
            True if the upload succeeded, False otherwise.
        """
        try:
            if not os.path.exists(file_path):
                print(f"文件不存在: {file_path}")
                return False
            # Read the whole file as bytes; the SDK expects a blob plus a
            # display name.
            with open(file_path, 'rb') as f:
                doc_content = f.read()
            dataset.upload_documents([{
                "display_name": os.path.basename(file_path),
                "blob": doc_content
            }])
            print(f"成功上传文件: {file_path}")
            return True
        except Exception as e:
            print(f"上传文档失败: {e}")
            return False

    def parse_documents(self, dataset, document_ids):
        """Kick off asynchronous parsing for the given documents.

        Args:
            dataset: Dataset object the documents belong to.
            document_ids: List of document IDs to parse.
        """
        try:
            dataset.async_parse_documents(document_ids)
            print("文档解析任务已启动")
        except Exception as e:
            print(f"启动文档解析失败: {e}")

    def create_chat(self, name, dataset_ids, llm=None, prompt=None):
        """Create a chat assistant over one or more datasets.

        Args:
            name: Assistant name.
            dataset_ids: List of dataset IDs the assistant can search.
            llm: Optional LLM settings.
            prompt: Optional prompt settings.

        Returns:
            The created chat assistant object, or None on failure.
        """
        try:
            chat = self.rag.create_chat(
                name=name,
                dataset_ids=dataset_ids,
                llm=llm,
                prompt=prompt
            )
            print(f"成功创建聊天助手: {name}")
            return chat
        except Exception as e:
            print(f"创建聊天助手失败: {e}")
            return None

    def start_chat_session(self, chat):
        """Run an interactive terminal chat loop against an assistant.

        Streams the answer chunk by chunk; each chunk carries the full
        answer so far, so only the new suffix is printed. Type 退出 /
        quit / exit (or Ctrl-C) to leave the loop.

        Args:
            chat: Chat assistant object.
        """
        try:
            session = chat.create_session("新会话")
            print("聊天会话已开始,输入'退出'结束对话")
            while True:
                try:
                    question = input("\n用户: ")
                    if question.lower() in ['退出', 'quit', 'exit']:
                        break
                    print("\n助手: ", end='')
                    response = ""
                    for chunk in session.ask(question, stream=True):
                        # chunk.content is cumulative; print only the delta.
                        new_content = chunk.content[len(response):]
                        print(new_content, end='', flush=True)
                        response = chunk.content
                    print()
                except KeyboardInterrupt:
                    print("\n用户中断对话")
                    break
                except Exception as e:
                    print(f"聊天过程中出错: {e}")
                    continue
        except Exception as e:
            print(f"创建会话失败: {e}")

    def retrieve(self, question, dataset_ids, document_ids=None, page=1, page_size=30,
                 similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024):
        """Retrieve matching chunks from the given datasets.

        Args:
            question: Query text.
            dataset_ids: List of dataset IDs to search.
            document_ids: Optional list of document IDs to restrict to.
            page: Page number.
            page_size: Items per page.
            similarity_threshold: Minimum similarity to include a chunk.
            vector_similarity_weight: Weight of vector vs. keyword score.
            top_k: Number of chunks participating in vector computation.

        Returns:
            A list of retrieved chunks ([] on failure).
        """
        try:
            chunks = self.rag.retrieve(
                question=question,
                dataset_ids=dataset_ids,
                document_ids=document_ids,
                page=page,
                page_size=page_size,
                similarity_threshold=similarity_threshold,
                vector_similarity_weight=vector_similarity_weight,
                top_k=top_k
            )
            return chunks
        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def create_openai_client(self, chat_id):
        """Create an OpenAI-compatible client bound to a chat assistant.

        Args:
            chat_id: Chat assistant ID.

        Returns:
            An OpenAI client object, or None on failure.
        """
        try:
            # Bug fix: reuse the key and URL supplied at construction time.
            # The old code hard-coded the API key and read
            # `self.rag.base_url`, which the SDK object does not expose,
            # so this method previously always returned None.
            client = OpenAI(
                api_key=self.api_key,
                base_url=f"{self.base_url}/api/v1/chats_openai/{chat_id}"
            )
            return client
        except Exception as e:
            print(f"创建OpenAI客户端失败: {e}")
            return None
# Usage example / smoke test: run this module directly to list datasets.
if __name__ == "__main__":
    # Build a client against the default local RAGFlow server.
    rf_client = RAGFlowClient()

    # Enumerate every dataset visible to this API key.
    datasets = rf_client.list_datasets()
    for ds in datasets:
        print(f"数据集: {ds.name}, ID: {ds.id}")

    # Further examples — uncomment to try:
    #
    # Create a dataset:
    # dataset = rf_client.create_dataset("测试数据集", "这是一个测试数据集")
    #
    # Upload a document into it:
    # if dataset:
    #     rf_client.upload_document(dataset, "test_data.txt")
    #
    # Build a chat assistant over the first dataset and talk to it:
    # if datasets:
    #     chat = rf_client.create_chat("测试助手", [datasets[0].id])
    #     rf_client.start_chat_session(chat)