Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 49 additions & 86 deletions api/db/services/knowledgebase_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class KnowledgebaseService(CommonService):
Attributes:
model: The Knowledgebase model class for database operations.
"""

model = Knowledgebase

@classmethod
Expand Down Expand Up @@ -75,8 +76,7 @@ def accessible4deletion(cls, kb_id, user_id):
2. The user is not the creator of the dataset
"""
# Check if a dataset can be deleted by a user
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = cls.model.select(cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
Expand Down Expand Up @@ -108,10 +108,10 @@ def is_parsed_done(cls, kb_id):
# Check parsing status of each document
for doc in docs:
# If document is being parsed, don't allow chat creation
if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
if doc["run"] == TaskStatus.RUNNING.value or doc["run"] == TaskStatus.CANCEL.value or doc["run"] == TaskStatus.FAIL.value:
return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
# If document is not yet parsed and has no chunks, don't allow chat creation
if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
if doc["run"] == TaskStatus.UNSTART.value and doc["chunk_num"] == 0:
return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."

return True, None
Expand All @@ -124,20 +124,14 @@ def list_documents_by_ids(cls, kb_ids):
# kb_ids: List of dataset IDs
# Returns:
# List of document IDs
doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
cls.model.id.in_(kb_ids)
)
doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(cls.model.id.in_(kb_ids))
doc_ids = list(doc_ids.dicts())
doc_ids = [doc["document_id"] for doc in doc_ids]
return doc_ids

@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page,
orderby, desc, keywords,
parser_id=None
):
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
# Get knowledge bases by tenant IDs with pagination and filtering
# Args:
# joined_tenant_ids: List of tenant IDs
Expand All @@ -164,23 +158,27 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
cls.model.parser_id,
cls.model.embd_id,
User.nickname,
User.avatar.alias('tenant_avatar'),
cls.model.update_time
User.avatar.alias("tenant_avatar"),
cls.model.update_time,
]
if keywords:
kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
kbs = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value),
(fn.LOWER(cls.model.name).contains(keywords.lower())),
)
)
else:
kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
kbs = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
)
if parser_id:
kbs = kbs.where(cls.model.parser_id == parser_id)
Expand Down Expand Up @@ -210,27 +208,13 @@ def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
cls.model.chunk_num,
cls.model.status,
cls.model.create_date,
cls.model.update_date
cls.model.update_date,
]
# find team kb and owned kb
kbs = cls.model.select(*fields).where(
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id
)
)
kbs = cls.model.select(*fields).where((cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
# sort by create_time asc
kbs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later.
offset, limit = 0, 50
res = []
while True:
kb_batch = kbs.offset(offset).limit(limit)
_temp = list(kb_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
kbs = kbs.order_by(cls.model.create_time.asc())
return list(kbs.dicts())

@classmethod
@DB.connection_context()
Expand Down Expand Up @@ -279,14 +263,14 @@ def get_detail(cls, kb_id):
cls.model.mindmap_task_id,
cls.model.mindmap_task_finish_at,
cls.model.create_time,
cls.model.update_time
]
kbs = cls.model.select(*fields)\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
).dicts()
cls.model.update_time,
]
kbs = (
cls.model.select(*fields)
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
.where((cls.model.id == kb_id), (cls.model.status == StatusEnum.VALID.value))
.dicts()
)
if not kbs:
return None
return kbs[0]
Expand Down Expand Up @@ -353,11 +337,7 @@ def get_by_name(cls, kb_name, tenant_id):
# tenant_id: Tenant ID
# Returns:
# Tuple of (exists, knowledge_base)
kb = cls.model.select().where(
(cls.model.name == kb_name)
& (cls.model.tenant_id == tenant_id)
& (cls.model.status == StatusEnum.VALID.value)
)
kb = cls.model.select().where((cls.model.name == kb_name) & (cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
if kb:
return True, kb[0]
return False, None
Expand All @@ -370,17 +350,9 @@ def get_all_ids(cls):
# List of all dataset IDs
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]


@classmethod
@DB.connection_context()
def create_with_name(
cls,
*,
name: str,
tenant_id: str,
parser_id: str | None = None,
**kwargs
):
def create_with_name(cls, *, name: str, tenant_id: str, parser_id: str | None = None, **kwargs):
"""Create a dataset (knowledgebase) by name with kb_app defaults.

This encapsulates the creation logic used in kb_app.create so other callers
Expand Down Expand Up @@ -420,7 +392,7 @@ def create_with_name(
"tenant_id": tenant_id,
"created_by": tenant_id,
"parser_id": (parser_id or "naive"),
**kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
**kwargs, # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
}

# Update parser_config (always override with validated default/merged config)
Expand All @@ -429,11 +401,9 @@ def create_with_name(

return True, payload


@classmethod
@DB.connection_context()
def get_list(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc, id, name):
def get_list(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, id, name):
# Get list of knowledge bases with filtering and pagination
# Args:
# joined_tenant_ids: List of tenant IDs
Expand All @@ -453,10 +423,7 @@ def get_list(cls, joined_tenant_ids, user_id,
if name:
kbs = kbs.where(cls.model.name == name)
kbs = kbs.where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value)
)

if desc:
Expand All @@ -478,9 +445,7 @@ def accessible(cls, kb_id, user_id):
# user_id: User ID
# Returns:
# Boolean indicating accessibility
docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = cls.model.select(cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
Expand All @@ -495,8 +460,7 @@ def get_kb_by_id(cls, kb_id, user_id):
# user_id: User ID
# Returns:
# List containing dataset information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts()
return list(kbs)

Expand All @@ -509,8 +473,7 @@ def get_kb_by_name(cls, kb_name, user_id):
# user_id: User ID
# Returns:
# List containing dataset information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts()
return list(kbs)

Expand Down Expand Up @@ -546,7 +509,7 @@ def update_document_number_in_init(cls, kb_id, doc_num):
kb.save(only=dirty_fields)
except ValueError as e:
if str(e) == "no data to save!":
pass # that's OK
pass # that's OK
else:
raise e

Expand All @@ -557,10 +520,10 @@ def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
if not kb_row:
raise RuntimeError(f"kb_id {kb_id} does not exist")
update_dict = {
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
'token_num': kb_row.token_num - doc_num_info['token_num'],
'update_time': current_timestamp(),
'update_date': datetime_format(datetime.now())
"doc_num": kb_row.doc_num - doc_num_info["doc_num"],
"chunk_num": kb_row.chunk_num - doc_num_info["chunk_num"],
"token_num": kb_row.token_num - doc_num_info["token_num"],
"update_time": current_timestamp(),
"update_date": datetime_format(datetime.now()),
}
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
45 changes: 45 additions & 0 deletions test/unit_test/services/test_knowledgebase_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest.mock import MagicMock, patch
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import DB


class TestKnowledgebaseServicePerf:
    """Regression test: get_all_kb_by_tenant_ids must fetch rows in a single
    query instead of the old offset/limit pagination loop."""

    @patch.object(DB, "connect")
    @patch.object(DB, "close")
    @patch.object(KnowledgebaseService.model, "select")
    def test_get_all_kb_by_tenant_ids_pagination(self, mock_select, mock_close, mock_connect):
        # Build the mocked peewee chain: select() -> where() -> order_by() -> dicts().
        query = MagicMock()
        query.where.return_value = query
        mock_select.return_value = query

        # order_by() yields a new query object whose dicts() returns one row.
        ordered = MagicMock()
        ordered.dicts.return_value = [{"id": "1"}]
        query.order_by.return_value = ordered

        rows = KnowledgebaseService.get_all_kb_by_tenant_ids(["tenant-1"], "user-1")

        # The single mocked row comes back unchanged.
        assert rows == [{"id": "1"}]

        # The removed implementation paginated with offset()/limit() on the
        # where() result; neither may be invoked by the new code path.
        assert query.offset.call_count == 0
        assert query.limit.call_count == 0

        # The new code path sorts exactly once and materializes via dicts().
        query.order_by.assert_called_once()
        ordered.dicts.assert_called_once()