Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion supernote/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .config import ServerConfig
from .db.session import DatabaseSessionManager
from .routes import auth, file, oss, schedule, system
from .routes import admin, auth, file, oss, schedule, system
from .services.blob import LocalBlobStorage
from .services.coordination import SqliteCoordinationService
from .services.file import FileService
Expand Down Expand Up @@ -159,6 +159,7 @@ def create_app(config: ServerConfig) -> web.Application:

# Register routes
app.add_routes(system.routes)
app.add_routes(admin.routes)
app.add_routes(auth.routes)
app.add_routes(file.routes)
app.add_routes(oss.routes)
Expand Down
83 changes: 83 additions & 0 deletions supernote/server/routes/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from collections.abc import Awaitable, Callable

from aiohttp import web
from mashumaro.exceptions import MissingField

from supernote.models.auth import UserVO
from supernote.models.base import BaseResponse, create_error_response
from supernote.models.user import UserRegisterDTO
from supernote.server.services.user import UserService

routes = web.RouteTableDef()


def require_admin(
handler: Callable[[web.Request], Awaitable[web.Response]],
) -> Callable[[web.Request], Awaitable[web.Response]]:
"""Decorator to require admin privileges."""

async def wrapper(request: web.Request) -> web.Response:
user_service: UserService = request.app["user_service"]
username = request.get("user")
if not username:
return web.json_response(
create_error_response("Unauthorized").to_dict(), status=401
)

user = await user_service._get_user_do(str(username))
if not user or not user.is_admin:
return web.json_response(
create_error_response("Forbidden: Admin access required").to_dict(),
status=403,
)

return await handler(request)

return wrapper


@routes.post("/api/admin/users")
@require_admin
async def handle_create_user(request: web.Request) -> web.Response:
"""Create a new user (Admin only)."""
req_data = await request.json()
try:
dto = UserRegisterDTO.from_dict(req_data)
except (MissingField, ValueError):
return web.json_response(
create_error_response("Invalid request format").to_dict(),
status=400,
)

user_service: UserService = request.app["user_service"]
try:
await user_service.create_user(dto)
return web.json_response(BaseResponse().to_dict())
except ValueError as e:
return web.json_response(create_error_response(str(e)).to_dict(), status=400)


@routes.get("/api/admin/users")
@require_admin
async def handle_list_users(request: web.Request) -> web.Response:
"""List all users (Admin only)."""
user_service: UserService = request.app["user_service"]
users = await user_service.list_users()

user_vos = [
UserVO(
user_name=u.display_name or u.username,
email=u.email or u.username,
phone=u.phone or "",
country_code="1",
total_capacity=u.total_capacity,
file_server="0",
avatars_url=u.avatar or "",
birthday="",
sex="",
)
for u in users
]

# Simple list response for now
return web.json_response([vo.to_dict() for vo in user_vos])
55 changes: 37 additions & 18 deletions supernote/server/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jwt
from mashumaro.mixins.json import DataClassJSONMixin
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from supernote.models.auth import LoginVO, UserVO
from supernote.models.user import (
Expand Down Expand Up @@ -69,33 +70,51 @@ async def check_user_exists(self, account: str) -> bool:
result = await session.execute(stmt)
return result.scalar_one_or_none() is not None

async def _create_user_entry(
self, session: AsyncSession, dto: UserRegisterDTO, is_admin: bool = False
) -> UserDO:
"""Internal helper to insert user into DB."""
if await self.check_user_exists(dto.email):
raise ValueError("User already exists")

# Hash password before storage.
# Future improvement: Upgrade to stronger hashing (e.g., bcrypt/argon2).
password_md5 = hashlib.md5(dto.password.encode()).hexdigest()

new_user = UserDO(
username=dto.email,
email=dto.email,
password_md5=password_md5,
display_name=dto.user_name,
is_active=True,
is_admin=is_admin,
)
session.add(new_user)
# Flush to get ID, but let caller commit
await session.flush()
return new_user

async def register(self, dto: UserRegisterDTO) -> UserDO:
"""Register a new user."""
"""Register a new user (Public/Self-Service)."""
async with self._session_manager.session() as session:
# Check for bootstrapping condition (no users exist). We allow registration if there are no users.
# even when registration is disabled when bootstrapping.
# Check for bootstrapping condition (no users exist)
user_count = (await session.execute(select(func.count(UserDO.id)))).scalar()
is_bootstrap = user_count == 0

if not self._config.enable_registration and not is_bootstrap:
raise ValueError("Registration is disabled")

if await self.check_user_exists(dto.email):
raise ValueError("User already exists")

# Hash password before storage.
# Future improvement: Upgrade to stronger hashing (e.g., bcrypt/argon2).
password_md5 = hashlib.md5(dto.password.encode()).hexdigest()

new_user = UserDO(
username=dto.email,
email=dto.email,
password_md5=password_md5,
display_name=dto.user_name,
is_active=True,
is_admin=is_bootstrap,
new_user = await self._create_user_entry(
session, dto, is_admin=is_bootstrap
)
session.add(new_user)
await session.commit()
await session.refresh(new_user)
return new_user

async def create_user(self, dto: UserRegisterDTO) -> UserDO:
"""Create a new user (Admin/System). Skips registration enabled check."""
async with self._session_manager.session() as session:
new_user = await self._create_user_entry(session, dto, is_admin=False)
await session.commit()
await session.refresh(new_user)
return new_user
Expand Down
20 changes: 20 additions & 0 deletions tests/server/services/test_user_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,23 @@ async def test_bootstrap_bypasses_disabled_registration(
)
with pytest.raises(ValueError, match="Registration is disabled"):
await service.register(dto2)


@pytest.mark.asyncio
async def test_admin_create_user_bypass(
session_manager: DatabaseSessionManager, coordination_service: CoordinationService
) -> None:
"""Test that create_user allows creating users when disabled."""
config = AuthConfig(enable_registration=False)
service = UserService(config, coordination_service, session_manager)

# Bootstrap first
await service.register(
UserRegisterDTO(email="admin@example.com", password="pw", user_name="Admin")
)

# Explicitly use create_user (Admin action simulation)
dto2 = UserRegisterDTO(email="new@example.com", password="pw", user_name="New")
user2 = await service.create_user(dto2)

assert user2.email == "new@example.com"
127 changes: 127 additions & 0 deletions tests/server/test_admin_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Any

import jwt
import pytest
from sqlalchemy import delete

from supernote.client.client import Client
from supernote.models.user import UserRegisterDTO
from supernote.server.config import ServerConfig
from supernote.server.db.models.user import UserDO
from supernote.server.db.session import DatabaseSessionManager
from supernote.server.services.coordination import CoordinationService
from supernote.server.services.user import JWT_ALGORITHM, UserService


@pytest.fixture
def admin_headers(server_config: ServerConfig) -> dict[str, Any]:
"""Headers for an ADMIN user."""
secret = server_config.auth.secret_key
token = jwt.encode({"sub": "admin@example.com"}, secret, algorithm=JWT_ALGORITHM)
return {"x-access-token": token}


@pytest.fixture
def user_headers(server_config: ServerConfig) -> dict[str, Any]:
"""Headers for a NORMAL user."""
secret = server_config.auth.secret_key
token = jwt.encode({"sub": "user@example.com"}, secret, algorithm=JWT_ALGORITHM)
return {"x-access-token": token}


async def setup_users(
session_manager: DatabaseSessionManager,
coordination_service: CoordinationService,
server_config: ServerConfig,
) -> None:
"""Helper to setup database state."""
async with session_manager.session() as session:
await session.execute(delete(UserDO))
await session.commit()

service = UserService(server_config.auth, coordination_service, session_manager)

# Register Admin (Bootstrapping)
await service.register(
UserRegisterDTO(email="admin@example.com", password="pw", user_name="Admin")
)

# Register Normal User (via Admin creation to ensure consistency,
# though strict bootstrapping only allows the FIRST user to be admin,
# so we use create_user for the second one if we wanted,
# but here we just need to ensure the DB state is correct).
# Easier: manually set is_admin=False for the second user if needed,
# but register() naturally makes 2nd user non-admin.
await service.register(
UserRegisterDTO(email="user@example.com", password="pw", user_name="User")
)

# Store sessions in coordination service
secret = server_config.auth.secret_key
admin_token = jwt.encode(
{"sub": "admin@example.com"}, secret, algorithm=JWT_ALGORITHM
)
user_token = jwt.encode(
{"sub": "user@example.com"}, secret, algorithm=JWT_ALGORITHM
)

await coordination_service.set_value(
f"session:{admin_token}", "admin@example.com|", ttl=3600
)
await coordination_service.set_value(
f"session:{user_token}", "user@example.com|", ttl=3600
)


async def test_admin_list_users_permission(
client: Client,
session_manager: DatabaseSessionManager,
coordination_service: CoordinationService,
server_config: ServerConfig,
admin_headers: dict[str, str],
user_headers: dict[str, str],
) -> None:
"""Test access control for listing users."""
await setup_users(session_manager, coordination_service, server_config)

# 1. Admin should succeed
resp = await client.get("/api/admin/users", headers=admin_headers)
assert resp.status == 200
data = await resp.json()
assert len(data) >= 2

# 2. Normal user should fail
resp = await client.get("/api/admin/users", headers=user_headers)
assert resp.status == 403

# 3. Anon should fail
resp = await client.get("/api/admin/users")
assert resp.status == 401


async def test_admin_create_user(
client: Client,
session_manager: DatabaseSessionManager,
coordination_service: CoordinationService,
server_config: ServerConfig,
admin_headers: dict[str, str],
) -> None:
"""Test admin creating a user."""
await setup_users(session_manager, coordination_service, server_config)

new_user = {
"email": "newbie@example.com",
"userName": "Newbie",
"password": "password",
"countryCode": "1",
}

# Admin creates user
resp = await client.post("/api/admin/users", json=new_user, headers=admin_headers)
assert resp.status == 200

# Verify user exists
resp = await client.get("/api/admin/users", headers=admin_headers)
data = await resp.json()
emails = [u["email"] for u in data]
assert "newbie@example.com" in emails