-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
195 lines (117 loc) · 4.81 KB
/
server.py
File metadata and controls
195 lines (117 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import asyncio
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Union

import jwt
from fastapi import FastAPI, File, Form, Request, UploadFile
from pydantic import BaseModel

try:
    # Package-relative import: used when this module is imported as part of a package.
    from .src import ocr_to_dict, image_to_ocr, chatbot, retriever
    from .src import auth as auth
except ImportError:
    # Fallback when there is no parent package (e.g. the module is loaded as a
    # top-level module from this directory).
    from src import ocr_to_dict, image_to_ocr, chatbot, retriever
    from src import auth as auth
# Put this file's directory on the import path so sibling modules resolve by
# absolute name. NOTE(review): this executes after the ``src`` imports at the
# top of the file, so it cannot influence how those particular imports resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Root logging configuration: append records at INFO and above to
# "server.log" in the current working directory. Use logging.DEBUG for more
# verbose output. (basicConfig is a no-op if the root logger already has
# handlers.)
logging.basicConfig(
    filename="server.log",
    filemode="a",
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module logger used by every endpoint below.
logger = logging.getLogger(__name__)
class ChatTurn(BaseModel):
    """One exchange in the conversation memory sent with a chat request."""

    # What the user said in this turn.
    user: str
    # The chatbot's reply: either plain text or a str -> str mapping.
    chatbot: Union[str, Dict[str, str]]
class ChatRequest(BaseModel):
    """Request body for ``POST /users/{db_id}/query``."""

    # The new user message to answer.
    message: str
    # Previous turns of the conversation, passed through to the chatbot.
    memory: List[ChatTurn]
class loginRegisterPayload(BaseModel):
    """Request body for ``POST /login-register``."""

    # Encoded token string handed to ``auth.decode`` (logged as a JWT payload
    # by the handler; presumably a signed JWT — confirm against src.auth).
    payload: str
# Directory (relative to the process working directory) where uploaded
# contact images are written.
IMAGE_DIR = "data/img_database"

# Process-wide components, built once at import time in this order and shared
# by every request.
ocr_model = image_to_ocr.OCRModel()
ocr_parser = ocr_to_dict.OCRParser()
contact_retriever = retriever.Retriever()

# NOTE(review): in the code visible here ``llm`` is never reassigned (handlers
# bind a local ``llm``); per-user chatbots are kept in ``llm_cache`` instead.
llm = None
# db_id -> chatbot instance, populated by the ``/user/{db_id}`` handler.
llm_cache: Dict[str, object] = {}

app = FastAPI()
@app.get("/")
async def chatbot_init():
    """Report which server components are available.

    Returns:
        dict: ``{"ocr_model", "ocr_parser", "contact_retriever", "llm"}`` mapped
        to booleans (truthiness of each component).
    """
    flags = {
        "ocr_model": bool(ocr_model),
        "ocr_parser": bool(ocr_parser),
        "contact_retriever": bool(contact_retriever),
        # The module-level ``llm`` is never reassigned in the visible code
        # (``retrieve`` stores chatbots in ``llm_cache`` under a local name), so
        # checking it alone always reported False. Treat the LLM as available
        # once at least one per-user chatbot has been cached.
        "llm": bool(llm) or bool(llm_cache),
    }
    logger.info("System component flags: %s", flags)
    return flags
def _redact_credentials(decoded_payload):
    """Return a log-safe copy of *decoded_payload*.

    For a dict, the ``password`` value (if present) is replaced with ``"***"``
    and every other key is kept. Non-dict values are returned unchanged so the
    log line matches what was logged before.
    """
    if not isinstance(decoded_payload, dict):
        return decoded_payload
    return {
        key: ("***" if key == "password" else value)
        for key, value in decoded_payload.items()
    }


@app.post("/login-register")
async def login_register(payload : loginRegisterPayload):
    """Log in or register a user and return the id of their index.

    ``payload.payload`` is decoded with ``auth.decode``. The decoded mapping must
    contain ``action``; for ``"login"`` / ``"register"`` it must also contain
    ``userid`` and ``password``.

    Returns:
        dict: ``{"db_id": ...}`` — the ``db_id`` from ``auth.login`` /
        ``auth.register``, or ``"faiss_index"`` when ``action`` is neither.

    Raises:
        KeyError: if a required key is missing from the decoded payload.
    """
    db_id = "faiss_index"
    logger.info("Received login/register request")
    decoded_payload = auth.decode(json_payload=payload.payload)
    # Never write the plaintext password to server.log.
    logger.info(f"Decoded JWT payload: {_redact_credentials(decoded_payload)}")
    action = decoded_payload['action']
    if action == "login":
        login_json_response = auth.login(userid=decoded_payload['userid'], password=decoded_payload["password"])
        db_id = login_json_response['db_id']
        logger.info(f"User {decoded_payload['userid']} logged in. DB ID: {db_id}")
    elif action == 'register':
        register_json_response = auth.register(userid=decoded_payload['userid'], password=decoded_payload["password"])
        db_id = register_json_response["db_id"]
        logger.info(f"User {decoded_payload['userid']} registered. DB ID: {db_id}")
    else:
        # NOTE(review): an unrecognised action still falls through to the
        # default index id instead of being rejected — confirm this is intended.
        logger.warning(f"Unknown login/register action {action!r}; returning default db_id={db_id}")
    return {"db_id": db_id}
@app.get("/user/{db_id}")
async def retrieve(db_id):
    """Load the index for ``db_id`` and cache a chatbot bound to it.

    Side effects:
        * calls ``contact_retriever.load_index(db_id)`` on the shared retriever;
        * stores a new ``chatbot.Chatbot`` in ``llm_cache[db_id]``, replacing any
          previous entry for that id.

    Returns:
        dict: ``{"response": <confirmation message>}``.
    """
    logger.info("Loading retriever for db_id=%s", db_id)
    contact_retriever.load_index(db_id)

    # NOTE(review): every cached Chatbot wraps the same module-level
    # ``contact_retriever``. If load_index replaces its contents, loading a
    # different db_id later would also change what earlier users' chatbots
    # search. Verify against retriever.Retriever.
    llm_cache[db_id] = chatbot.Chatbot(retriever=contact_retriever)
    logger.info("Chatbot loaded and cached for db_id=%s", db_id)

    return {"response": f"retirever loaded for user databasse{db_id}"}
@app.post("/users/{db_id}/query")
async def query_retriever(db_id:str , query : ChatRequest):
    """Answer ``query.message`` with the chatbot cached for ``db_id``.

    Returns:
        dict: ``{"response": ...}`` with the chatbot's answer, or
        ``{"error": ...}`` if no chatbot has been loaded for ``db_id`` (via the
        ``/user/{db_id}`` endpoint).
    """
    user_message = query.message
    memory = query.memory
    logger.info(
        "Received query for db_id=%s: %s || length of conversation memory %s",
        db_id,
        user_message,
        len(memory),
    )

    try:
        bot = llm_cache[db_id]
    except KeyError:
        logger.warning("LLM not loaded for db_id=%s", db_id)
        return {"error": "user chatot not loaded"}

    # ``memory`` is passed through as the list of ChatTurn models.
    response = bot.chatcompletion(query=user_message, memory=memory)
    logger.info("Chatbot response: %s", response)
    return {"response": response}
def _safe_upload_filename(filename):
    """Reduce a client-supplied upload name to a bare file name.

    Directory components (``/`` or ``\\``) are stripped, so a name such as
    ``../../etc/passwd`` cannot escape ``IMAGE_DIR``. When nothing usable is
    left (missing name, ``"."`` or ``".."``), a random ``upload_<hex>`` name is
    used instead.
    """
    import uuid

    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        name = f"upload_{uuid.uuid4().hex}"
    return name


@app.post("/users/{db_id}/add_contacts")
async def add_to_retriever(
    db_id: str,
    contact_images: List[UploadFile] = File(...),
):
    """Save uploaded contact images, OCR them and add the parsed contacts.

    Args:
        db_id: Passed to ``contact_retriever.load_index`` before adding data.
        contact_images: Uploaded image files.

    Returns:
        dict: ``{"status": "contacts added", "count": <number of parsed contacts>}``.
    """
    image_paths = []
    logger.info(f"Received {len(contact_images)} image(s) for db_id: {db_id}")
    # IMAGE_DIR is relative to the working directory and may not exist yet.
    os.makedirs(IMAGE_DIR, exist_ok=True)
    for image_file in contact_images:
        contents = await image_file.read()
        # ``filename`` comes from the client; never use it as a path verbatim.
        file_path = f"{IMAGE_DIR}/{_safe_upload_filename(image_file.filename)}"
        # NOTE(review): files are stored flat under IMAGE_DIR, so uploads with
        # the same name (from any db_id) overwrite each other.
        with open(file_path, "wb") as f:
            f.write(contents)
        image_paths.append(file_path)
        logger.info(f"Saved uploaded file: {file_path}")
    logger.info("Running OCR on uploaded images...")
    ocr_texts, _ = ocr_model.getOCRtext(image_paths=image_paths)
    logger.info("Parsing OCR text to contact dictionaries...")
    contact_dicts = ocr_parser.parseOCRTexts(ocr_texts=ocr_texts, image_paths=image_paths)
    logger.info(f"Loading contact retriever for db_id={db_id}")
    contact_retriever.load_index(index_path=db_id)
    logger.info(f"Adding {len(contact_dicts)} contact(s) to retriever")
    contact_retriever.add_data(data_list=contact_dicts)
    logger.info(f"Successfully added contacts for db_id={db_id}")
    return {"status": "contacts added", "count": len(contact_dicts)}
# @app.post()