diff --git a/supernote/server/app.py b/supernote/server/app.py index ad0923f..fd55bbd 100644 --- a/supernote/server/app.py +++ b/supernote/server/app.py @@ -11,7 +11,7 @@ from .config import ServerConfig from .db.session import DatabaseSessionManager -from .routes import auth, file, oss, schedule, system +from .routes import admin, auth, file, oss, schedule, system from .services.blob import LocalBlobStorage from .services.coordination import SqliteCoordinationService from .services.file import FileService @@ -159,6 +159,7 @@ def create_app(config: ServerConfig) -> web.Application: # Register routes app.add_routes(system.routes) + app.add_routes(admin.routes) app.add_routes(auth.routes) app.add_routes(file.routes) app.add_routes(oss.routes) diff --git a/supernote/server/routes/admin.py b/supernote/server/routes/admin.py new file mode 100644 index 0000000..65f5162 --- /dev/null +++ b/supernote/server/routes/admin.py @@ -0,0 +1,83 @@ +from collections.abc import Awaitable, Callable + +from aiohttp import web +from mashumaro.exceptions import MissingField + +from supernote.models.auth import UserVO +from supernote.models.base import BaseResponse, create_error_response +from supernote.models.user import UserRegisterDTO +from supernote.server.services.user import UserService + +routes = web.RouteTableDef() + + +def require_admin( + handler: Callable[[web.Request], Awaitable[web.Response]], +) -> Callable[[web.Request], Awaitable[web.Response]]: + """Decorator to require admin privileges.""" + + async def wrapper(request: web.Request) -> web.Response: + user_service: UserService = request.app["user_service"] + username = request.get("user") + if not username: + return web.json_response( + create_error_response("Unauthorized").to_dict(), status=401 + ) + + user = await user_service._get_user_do(str(username)) + if not user or not user.is_admin: + return web.json_response( + create_error_response("Forbidden: Admin access required").to_dict(), + status=403, + ) + + return await handler(request) + + return wrapper + + +@routes.post("/api/admin/users") +@require_admin +async def handle_create_user(request: web.Request) -> web.Response: + """Create a new user (Admin only).""" + req_data = await request.json() + try: + dto = UserRegisterDTO.from_dict(req_data) + except (MissingField, ValueError): + return web.json_response( + create_error_response("Invalid request format").to_dict(), + status=400, + ) + + user_service: UserService = request.app["user_service"] + try: + await user_service.create_user(dto) + return web.json_response(BaseResponse().to_dict()) + except ValueError as e: + return web.json_response(create_error_response(str(e)).to_dict(), status=400) + + +@routes.get("/api/admin/users") +@require_admin +async def handle_list_users(request: web.Request) -> web.Response: + """List all users (Admin only).""" + user_service: UserService = request.app["user_service"] + users = await user_service.list_users() + + user_vos = [ + UserVO( + user_name=u.display_name or u.username, + email=u.email or u.username, + phone=u.phone or "", + country_code="1", + total_capacity=u.total_capacity, + file_server="0", + avatars_url=u.avatar or "", + birthday="", + sex="", + ) + for u in users + ] + + # Simple list response for now + return web.json_response([vo.to_dict() for vo in user_vos]) diff --git a/supernote/server/services/user.py b/supernote/server/services/user.py index 187a49c..994a90d 100644 --- a/supernote/server/services/user.py +++ b/supernote/server/services/user.py @@ -9,6 +9,7 @@ import jwt from mashumaro.mixins.json import DataClassJSONMixin from sqlalchemy import delete, func, select, update +from sqlalchemy.ext.asyncio import AsyncSession from supernote.models.auth import LoginVO, UserVO from supernote.models.user import ( @@ -69,33 +70,51 @@ async def check_user_exists(self, account: str) -> bool: result = await session.execute(stmt) return result.scalar_one_or_none() is not None + async def _create_user_entry( + self, session: AsyncSession, dto: UserRegisterDTO, is_admin: bool = False + ) -> UserDO: + """Internal helper to insert user into DB.""" + if await self.check_user_exists(dto.email): + raise ValueError("User already exists") + + # Hash password before storage. + # Future improvement: Upgrade to stronger hashing (e.g., bcrypt/argon2). + password_md5 = hashlib.md5(dto.password.encode()).hexdigest() + + new_user = UserDO( + username=dto.email, + email=dto.email, + password_md5=password_md5, + display_name=dto.user_name, + is_active=True, + is_admin=is_admin, + ) + session.add(new_user) + # Flush to get ID, but let caller commit + await session.flush() + return new_user + async def register(self, dto: UserRegisterDTO) -> UserDO: - """Register a new user.""" + """Register a new user (Public/Self-Service).""" async with self._session_manager.session() as session: - # Check for bootstrapping condition (no users exist). We allow registration if there are no users. - # even when registration is disabled when bootstrapping. + # Check for bootstrapping condition (no users exist) user_count = (await session.execute(select(func.count(UserDO.id)))).scalar() is_bootstrap = user_count == 0 if not self._config.enable_registration and not is_bootstrap: raise ValueError("Registration is disabled") - if await self.check_user_exists(dto.email): - raise ValueError("User already exists") - - # Hash password before storage. - # Future improvement: Upgrade to stronger hashing (e.g., bcrypt/argon2). - password_md5 = hashlib.md5(dto.password.encode()).hexdigest() - - new_user = UserDO( - username=dto.email, - email=dto.email, - password_md5=password_md5, - display_name=dto.user_name, - is_active=True, - is_admin=is_bootstrap, + new_user = await self._create_user_entry( + session, dto, is_admin=is_bootstrap ) - session.add(new_user) + await session.commit() + await session.refresh(new_user) + return new_user + + async def create_user(self, dto: UserRegisterDTO) -> UserDO: + """Create a new user (Admin/System). Skips registration enabled check.""" + async with self._session_manager.session() as session: + new_user = await self._create_user_entry(session, dto, is_admin=False) await session.commit() await session.refresh(new_user) return new_user diff --git a/tests/server/services/test_user_bootstrap.py b/tests/server/services/test_user_bootstrap.py index 09a96b1..36693c7 100644 --- a/tests/server/services/test_user_bootstrap.py +++ b/tests/server/services/test_user_bootstrap.py @@ -61,3 +61,23 @@ async def test_bootstrap_bypasses_disabled_registration( ) with pytest.raises(ValueError, match="Registration is disabled"): await service.register(dto2) + + +@pytest.mark.asyncio +async def test_admin_create_user_bypass( + session_manager: DatabaseSessionManager, coordination_service: CoordinationService +) -> None: + """Test that create_user allows creating users when disabled.""" + config = AuthConfig(enable_registration=False) + service = UserService(config, coordination_service, session_manager) + + # Bootstrap first + await service.register( + UserRegisterDTO(email="admin@example.com", password="pw", user_name="Admin") + ) + + # Explicitly use create_user (Admin action simulation) + dto2 = UserRegisterDTO(email="new@example.com", password="pw", user_name="New") + user2 = await service.create_user(dto2) + + assert user2.email == "new@example.com" diff --git a/tests/server/test_admin_api.py b/tests/server/test_admin_api.py new file mode 100644 index 0000000..96a5c76 --- /dev/null +++ b/tests/server/test_admin_api.py @@ -0,0 +1,127 @@ +from typing import Any + +import jwt +import pytest +from sqlalchemy import delete + +from supernote.client.client import Client +from supernote.models.user import UserRegisterDTO +from supernote.server.config import ServerConfig +from supernote.server.db.models.user import UserDO +from supernote.server.db.session import DatabaseSessionManager +from supernote.server.services.coordination import CoordinationService +from supernote.server.services.user import JWT_ALGORITHM, UserService + + +@pytest.fixture +def admin_headers(server_config: ServerConfig) -> dict[str, Any]: + """Headers for an ADMIN user.""" + secret = server_config.auth.secret_key + token = jwt.encode({"sub": "admin@example.com"}, secret, algorithm=JWT_ALGORITHM) + return {"x-access-token": token} + + +@pytest.fixture +def user_headers(server_config: ServerConfig) -> dict[str, Any]: + """Headers for a NORMAL user.""" + secret = server_config.auth.secret_key + token = jwt.encode({"sub": "user@example.com"}, secret, algorithm=JWT_ALGORITHM) + return {"x-access-token": token} + + +async def setup_users( + session_manager: DatabaseSessionManager, + coordination_service: CoordinationService, + server_config: ServerConfig, +) -> None: + """Helper to setup database state.""" + async with session_manager.session() as session: + await session.execute(delete(UserDO)) + await session.commit() + + service = UserService(server_config.auth, coordination_service, session_manager) + + # Register Admin (Bootstrapping) + await service.register( + UserRegisterDTO(email="admin@example.com", password="pw", user_name="Admin") + ) + + # Register Normal User (via Admin creation to ensure consistency, + # though strict bootstrapping only allows the FIRST user to be admin, + # so we use create_user for the second one if we wanted, + # but here we just need to ensure the DB state is correct). + # Easier: manually set is_admin=False for the second user if needed, + # but register() naturally makes 2nd user non-admin. + await service.register( + UserRegisterDTO(email="user@example.com", password="pw", user_name="User") + ) + + # Store sessions in coordination service + secret = server_config.auth.secret_key + admin_token = jwt.encode( + {"sub": "admin@example.com"}, secret, algorithm=JWT_ALGORITHM + ) + user_token = jwt.encode( + {"sub": "user@example.com"}, secret, algorithm=JWT_ALGORITHM + ) + + await coordination_service.set_value( + f"session:{admin_token}", "admin@example.com|", ttl=3600 + ) + await coordination_service.set_value( + f"session:{user_token}", "user@example.com|", ttl=3600 + ) + + +async def test_admin_list_users_permission( + client: Client, + session_manager: DatabaseSessionManager, + coordination_service: CoordinationService, + server_config: ServerConfig, + admin_headers: dict[str, str], + user_headers: dict[str, str], +) -> None: + """Test access control for listing users.""" + await setup_users(session_manager, coordination_service, server_config) + + # 1. Admin should succeed + resp = await client.get("/api/admin/users", headers=admin_headers) + assert resp.status == 200 + data = await resp.json() + assert len(data) >= 2 + + # 2. Normal user should fail + resp = await client.get("/api/admin/users", headers=user_headers) + assert resp.status == 403 + + # 3. Anon should fail + resp = await client.get("/api/admin/users") + assert resp.status == 401 + + +async def test_admin_create_user( + client: Client, + session_manager: DatabaseSessionManager, + coordination_service: CoordinationService, + server_config: ServerConfig, + admin_headers: dict[str, str], +) -> None: + """Test admin creating a user.""" + await setup_users(session_manager, coordination_service, server_config) + + new_user = { + "email": "newbie@example.com", + "userName": "Newbie", + "password": "password", + "countryCode": "1", + } + + # Admin creates user + resp = await client.post("/api/admin/users", json=new_user, headers=admin_headers) + assert resp.status == 200 + + # Verify user exists + resp = await client.get("/api/admin/users", headers=admin_headers) + data = await resp.json() + emails = [u["email"] for u in data] + assert "newbie@example.com" in emails