Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions onebot/plugins/acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

import asyncio
import json
from typing import Self

import irc3
Expand Down Expand Up @@ -91,7 +92,9 @@ def __init__(self, bot):
self.log.debug("Config: %r", self.config)
if "superadmin" in self.config:
self.log.info("Giving {} all_permissions".format(self.config["superadmin"]))
self.bot.db.set(self.config["superadmin"], permissions=["all_permissions"])
self.bot.db.set(
self.config["superadmin"], permissions=json.dumps(["all_permissions"])
)

@command(permission="admin", show_in_help_list=False)
async def acl(self, mask, target, args) -> None:
Expand Down Expand Up @@ -127,6 +130,7 @@ async def acl(self, mask, target, args) -> None:
current_permissions = self.bot.db.get(args["<id>"], {}).get(
"permissions", []
)
current_permissions = self.bot.deserialize_setting(current_permissions)

if args["add"] and permission not in current_permissions:
current_permissions.append(permission)
Expand All @@ -137,9 +141,7 @@ async def acl(self, mask, target, args) -> None:
assert user is not None
user.set_setting("permissions", current_permissions)
else:
if args["<id>"] not in self.bot.db:
self.bot.db[args["<id>"]] = {}
self.bot.db[args["<id>"]]["permissions"] = current_permissions
self.bot.db.set(args["<id>"], permissions=json.dumps(current_permissions))

self.bot.privmsg(
target,
Expand Down
39 changes: 28 additions & 11 deletions onebot/plugins/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from __future__ import unicode_literals, print_function

import ast
import asyncio
import json
import re
from typing import (
Any,
Expand All @@ -31,6 +31,22 @@
from irc3.utils import IrcString


def deserialize_setting(value: Any) -> Any:
"""Safely deserialize a setting value"""
if not isinstance(value, str):
return value

# Try to parse as JSON
try:
# Handles JSON lists, dicts, booleans (true/false), null, and numbers
return json.loads(value)
except (ValueError, json.JSONDecodeError):
pass

# Return as a plain string
return value


class User(object):
"""User object"""

Expand Down Expand Up @@ -75,7 +91,12 @@ async def wrapper() -> None:

def set_setting(self, setting: str, value: Any) -> None:
"""Set a specified setting to a value"""
print("Trying to set %s to %s" % (setting, value))

# Serialize non-string types to JSON for consistent storage across backends
if not isinstance(value, str):
if isinstance(value, set):
value = list(value)
value = json.dumps(value)

async def wrapper():
id_ = await self.id()
Expand All @@ -91,15 +112,7 @@ async def get_settings(self) -> Dict[str, Any]:
async def get_setting(self, setting, default=None) -> Any:
"""Gets a setting for the users. Can be any type."""
settings = await self.get_settings()
result = settings.get(setting, default)
if isinstance(result, str):
try:
parsed = ast.literal_eval(result)
return parsed
except (ValueError, SyntaxError):
pass

return result
return deserialize_setting(settings.get(setting, default))

def join(self, channel) -> None:
"""Register that the user joined a channel"""
Expand Down Expand Up @@ -168,6 +181,10 @@ def get_user(self, nick: str):
self.log.warning("Couldn't find %s!", nick)
return user

@irc3.extend
def deserialize_setting(self, value: Any) -> Any:
return deserialize_setting(value)

@irc3.extend
def redact_nicks(self, message: str, target: Optional[str] = None) -> str:
"""Redacts all known nicks in the message.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_plugin_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def tearDown(self):
def test_command_allowed(self):
async def wrap():
self.bot.dispatch(":im!the@boss JOIN #chan")
self.bot.db["the@boss"] = {"permissions": {"test"}}
self.bot.db["the@boss"] = {"permissions": '["test"]'}
self.bot.dispatch(":im!the@boss PRIVMSG #chan :!cmd")
await asyncio.sleep(0.001)

Expand All @@ -75,7 +75,7 @@ async def wrap():
def test_command_ignored(self):
async def wrap():
self.bot.dispatch(":Groxxxy!stupid@idiot JOIN #chan")
self.bot.db["stupid@idiot"] = {"permissions": {"ignore"}}
self.bot.db["stupid@idiot"] = {"permissions": '["ignore"]'}
self.bot.dispatch(":Groxxxy!stupid@idiot PRIVMSG #chan :!cmd2")
await asyncio.sleep(0.001)

Expand Down Expand Up @@ -116,7 +116,7 @@ async def wrap():
await asyncio.sleep(0.001)

self.bot.loop.run_until_complete(wrap())
self.assertEqual(self.bot.db["foo@host"].get("permissions"), ["admin"])
self.assertEqual(self.bot.db["foo@host"].get("permissions"), '["admin"]')
self.assertSent(["PRIVMSG #chan :Updated permissions for bar"])

def test_add_unknown_user(self):
Expand All @@ -136,7 +136,7 @@ async def wrap():
await asyncio.sleep(0.001)

self.bot.loop.run_until_complete(wrap())
self.assertEqual(self.bot.db["bak"].get("permissions"), ["admin"])
self.assertEqual(self.bot.db["bak"].get("permissions"), '["admin"]')
self.assertSent(["PRIVMSG #chan :Updated permissions for bak"])

def test_invalid_permission(self):
Expand All @@ -154,14 +154,14 @@ async def wrap():
)

def test_remove_acl(self):
self.bot.db["foo@host"] = {"permissions": {"admin"}}
self.bot.db["foo@host"] = {"permissions": '["admin"]'}

# sanity check
self.assertEqual(self.bot.db["foo@host"].get("permissions"), {"admin"})
self.assertEqual(self.bot.db["foo@host"].get("permissions"), '["admin"]')

async def wrap():
self.bot.dispatch(":root@localhost PRIVMSG #chan :!acl remove bar admin")
await asyncio.sleep(0.001)

self.bot.loop.run_until_complete(wrap())
self.assertEqual(self.bot.db["foo@host"].get("permissions"), set())
self.assertEqual(self.bot.db["foo@host"].get("permissions"), "[]")