-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag.py
More file actions
145 lines (116 loc) · 6.34 KB
/
rag.py
File metadata and controls
145 lines (116 loc) · 6.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import logging
from typing import List, Dict, Any
from openai import OpenAI
from config import OPENAI_API_KEY, RAG_SETTINGS, MODELS
from retrieval import rerank_items as search_similar_items
from db import get_items_sample
import tiktoken
from debug_utils import debug_step
from openai_api_models import client
from utils import timeit
logger = logging.getLogger(__name__)
def num_tokens_from_string(string: str, model: str = None) -> int:
"""Возвращает количество токенов в строке"""
if model is None:
model = MODELS['generation']['name']
try:
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(string))
except KeyError:
logger.warning(f"Модель {model} не найдена, используем gpt-3.5-turbo")
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
return len(encoding.encode(string))
def truncate_text(text: str, max_tokens: int = None) -> str:
"""Обрезает текст до указанного количества токенов"""
if max_tokens is None:
max_tokens = MODELS['generation']['max_tokens']
logger.debug(f"Обрезаем текст. Исходная длина: {len(text)} символов")
encoding = tiktoken.encoding_for_model(MODELS['generation']['name'])
tokens = encoding.encode(text)
logger.debug(f"Количество токенов: {len(tokens)}")
if len(tokens) <= max_tokens:
return text
truncated = encoding.decode(tokens[:max_tokens])
logger.debug(f"Текст обрезан. Новая длина: {len(truncated)} символов")
return truncated
@timeit
def generate_prompt(query: str, context_items: List[Dict[str, Any]]) -> str:
"""Генерирует промпт для модели"""
logger.debug(f"Генерация промпта для запроса: {query}")
# Получаем параметры контекста (возможно, обновленные пользователем)
context_params = debug_step('context') or RAG_SETTINGS
# Ограничиваем количество токенов для каждого контекста
max_tokens_per_context = context_params['chunk_size']
context_texts = []
for i, item in enumerate(context_items):
logger.debug(f"Обработка контекста {i+1}")
context = item['text']
truncated_context = truncate_text(context, max_tokens_per_context)
context_texts.append(f"Context {i+1}:\n{truncated_context}")
context_str = "\n\n".join(context_texts)
total_tokens = num_tokens_from_string(context_str)
logger.debug(f"Общее количество токенов в контексте: {total_tokens}")
prompt = f"""Используй следующий контекст для ответа на вопрос.
Если информации недостаточно, скажи об этом.
{context_str}
Вопрос: {query}
Ответ:"""
# Отладка: показываем собранный контекст и промпт
debug_step('context', {
'context_count': len(context_texts),
'total_tokens': total_tokens,
'context': context_str
})
debug_step('generation', {
'model': MODELS['generation']['name'],
'prompt': prompt,
'tokens': num_tokens_from_string(prompt)
})
return prompt
@timeit
def generate_answer(query: str, context_items: List[Dict[str, Any]]) -> str:
"""Генерирует ответ на основе контекста"""
logger.debug("Генерация ответа")
# Проверяем, есть ли реальный контекст
has_real_context = False
for item in context_items:
if not item['text'].startswith("Информация отсутствует"):
has_real_context = True
break
# Если нет реального контекста, сообщаем об этом
if not has_real_context:
logger.warning("Отсутствует релевантный контекст, используем более простую модель")
return "Извините, в базе знаний не найдено информации по вашему запросу. Пожалуйста, уточните вопрос или используйте другие ключевые слова."
prompt = generate_prompt(query, context_items)
# Получаем параметры генерации (возможно, обновленные пользователем)
gen_params = debug_step('generation') or RAG_SETTINGS
total_tokens = num_tokens_from_string(prompt)
logger.debug(f"Общее количество токенов в промпте: {total_tokens}")
# Оставляем запас для ответа
max_prompt_tokens = MODELS['generation']['max_tokens'] - gen_params['max_tokens']
if total_tokens > max_prompt_tokens:
logger.warning(f"Превышен безопасный лимит токенов: {total_tokens}")
prompt = truncate_text(prompt, max_prompt_tokens)
logger.debug("Промпт обрезан до безопасного размера")
try:
response = client.chat.completions.create(
model=MODELS['generation']['name'],
messages=[
{"role": "system", "content": "Ты помощник, который отвечает на вопросы, используя предоставленный контекст."},
{"role": "user", "content": prompt}
],
temperature=gen_params['temperature'],
max_tokens=gen_params['max_tokens']
)
logger.debug("Ответ получен от API")
answer = response.choices[0].message.content
# Оставить только этот один вызов в конце
debug_step('generation', {
'model': MODELS['generation']['name'],
'answer_tokens': num_tokens_from_string(answer),
'answer': answer
})
return answer
except Exception as e:
logger.error(f"Ошибка при получении ответа от API: {str(e)}")
raise