Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ openrelik-ai-common = "~0.5.0"
openrelik-common = "^0.7.6"
sse-starlette = "^2.3.6"
duckdb = "^1.3.2"
pytz = "^2026.1.post1"

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
Expand Down
1 change: 1 addition & 0 deletions src/api/v1/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ def generate_query(
llm_model=active_llm["config"]["model"],
tables_schemas=json.dumps(tables_schemas),
user_request=request.user_request,
file=file,
)
return {"user_request": request.user_request, "generated_query": generated_query}

Expand Down
45 changes: 23 additions & 22 deletions src/datastores/sql/crud/authz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

from datastores.sql.models.file import File
from datastores.sql.models.folder import Folder
from datastores.sql.models.group import GroupRole
from datastores.sql.models.role import Role
from datastores.sql.models.user import User, UserRole
from datastores.sql.models.group import GroupRole


class AuthorizationError(Exception):
Expand Down Expand Up @@ -113,9 +113,7 @@ def check_user_access(
for group in user.groups:
group_role = (
db.query(GroupRole)
.filter(
GroupRole.group_id == group.id, GroupRole.folder_id == folder.id
)
.filter(GroupRole.group_id == group.id, GroupRole.folder_id == folder.id)
.first()
)
if group_role and group_role.role in allowed_roles:
Expand All @@ -126,12 +124,11 @@ def check_user_access(
return False # No access found


def require_access(
allowed_roles: list, http_exception: bool = True, error_message: str = None
):
def require_access(allowed_roles: list, http_exception: bool = True, error_message: str = None):
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):

def _check_access(kwargs):
"""Shared access check logic for both sync and async wrappers."""
db = kwargs.get("db")
folder_id = kwargs.get("folder_id")
file_id = kwargs.get("file_id")
Expand All @@ -141,9 +138,7 @@ async def wrapper(*args, **kwargs):
folder = db.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found.")
if not check_user_access(
db, current_user, allowed_roles, folder=folder
):
if not check_user_access(db, current_user, allowed_roles, folder=folder):
raise_authorization_error(
http_exception, error_message or "No access to folder"
)
Expand All @@ -152,18 +147,24 @@ async def wrapper(*args, **kwargs):
file = db.get(File, file_id)
if not file:
raise HTTPException(status_code=404, detail="File not found.")
if not check_user_access(
db, current_user, allowed_roles, file=file
):
raise_authorization_error(
http_exception, error_message or "No access to file"
)
# Await only if func is async
if asyncio.iscoroutinefunction(func):
if not check_user_access(db, current_user, allowed_roles, file=file):
raise_authorization_error(http_exception, error_message or "No access to file")

if asyncio.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args, **kwargs):
_check_access(kwargs)
return await func(*args, **kwargs)

return func(*args, **kwargs) # Call directly if func is sync
return async_wrapper
else:

@wraps(func)
def sync_wrapper(*args, **kwargs):
_check_access(kwargs)
return func(*args, **kwargs)

return wrapper
return sync_wrapper

return decorator
Loading