diff --git a/.github/workflows/vul-test.yml b/.github/workflows/vul-test.yml index d108c8e..e9ebafa 100644 --- a/.github/workflows/vul-test.yml +++ b/.github/workflows/vul-test.yml @@ -56,13 +56,13 @@ jobs: - name: Run Vul test run: | - curl --fail --retry-delay 10 --retry 30 --retry-connrefused http://127.0.0.1:8003/api/django/demo/get_open?name=Data cd ${{ github.workspace }}/DockerVulspace + curl --fail --retry-delay 10 --retry 30 --retry-connrefused http://127.0.0.1:8003/api/django/demo/get_open?name=Data + docker-compose logs djangoweb flaskweb docker-compose exec -T djangoweb python -V docker-compose exec -T djangoweb pip list docker-compose exec -T flaskweb python -V docker-compose exec -T flaskweb pip list - docker-compose logs djangoweb flaskweb bash ${{ github.workspace }}/DongTai-agent-python/dongtai_agent_python/tests/vul-test.sh \ django http://127.0.0.1:8003/api/django ${{ github.run_id }} bash ${{ github.workspace }}/DongTai-agent-python/dongtai_agent_python/tests/vul-test.sh \ diff --git a/dongtai_agent_python/api/openapi.py b/dongtai_agent_python/api/openapi.py index 656e69c..a5f1dcb 100644 --- a/dongtai_agent_python/api/openapi.py +++ b/dongtai_agent_python/api/openapi.py @@ -187,6 +187,8 @@ def agent_register(self): logger.error("register get agent id empty") return resp + self.setting.set_shm('agent-' + str(self.agent_id)) + if resp.get('data', {}).get('coreAutoStart', 0) != 1: logger.info("agent is waiting for auditing") self.setting.dt_manual_pause = True diff --git a/dongtai_agent_python/setting/setting.py b/dongtai_agent_python/setting/setting.py index 754bf28..a76880c 100644 --- a/dongtai_agent_python/setting/setting.py +++ b/dongtai_agent_python/setting/setting.py @@ -1,8 +1,13 @@ import os +from multiprocessing import Lock from dongtai_agent_python import version from .config import Config from dongtai_agent_python.utils import Singleton +from dongtai_agent_python.utils.shm import SharedMemoryDict +from dongtai_agent_python.utils.lock import lock + +_lock = Lock() class Setting(Singleton): @@ -13,10 +18,8 @@ def init(self): return self.version = version.__version__ - self.paused = False - self.manual_paused = False self.agent_id = 0 - self.request_seq = 0 + self.shm = None self.auto_create_project = 0 self.use_local_policy = False @@ -38,6 +41,12 @@ def init(self): self.init_os_environ() Setting.loaded = True + def __del__(self): + if self.shm is None: + return False + self.shm.close() + self.shm.unlink() + def set_container(self, container): if container and isinstance(container, dict): self.container = container @@ -76,8 +85,45 @@ def init_os_environ(self): for key in os_env.keys(): self.os_env_list.append(key + '=' + str(os_env[key])) + def set_shm(self, name): + if self.shm is None: + self.shm = SharedMemoryDict('dongtai-shm-python-' + name) + + @property + def paused(self): + if self.shm is None: + return False + return self.shm.get('paused') + + @paused.setter + def paused(self, status): + if self.shm is None: + return + self.shm['paused'] = status + + @property + def manual_paused(self): + if self.shm is None: + return False + return self.shm.get('manual_paused') + + @manual_paused.setter + def manual_paused(self, status): + if self.shm is None: + return + self.shm['manual_paused'] = status + def is_agent_paused(self): - return self.paused and self.manual_paused + return self.paused or self.manual_paused + @property + def request_seq(self): + if self.shm is None: + return 0 + return self.shm.get('request_seq', 0) + + @lock(_lock) def incr_request_seq(self): - self.request_seq = self.request_seq + 1 + if self.shm is None: + return + self.shm['request_seq'] = self.request_seq + 1 diff --git a/dongtai_agent_python/tests/setting/test_setting.py b/dongtai_agent_python/tests/setting/test_setting.py index 7461eb8..8ea8e69 100644 --- a/dongtai_agent_python/tests/setting/test_setting.py +++ b/dongtai_agent_python/tests/setting/test_setting.py @@ -1,4 +1,7 @@ +import multiprocessing +import os import threading +import time import unittest from dongtai_agent_python.setting.setting import Setting @@ -6,19 +9,46 @@ class TestSetting(unittest.TestCase): def test_multithreading(self): - def test(name): + def test_mt(name): st1 = Setting() + st1.shm = None + st1.set_shm("test-setting-001") st1.set_container({'name': name, 'version': '0.1'}) st1.incr_request_seq() thread_num = 5 for i in range(thread_num): - t = threading.Thread(target=test, args=['test' + str(i)]) + t = threading.Thread(target=test_mt, args=['test' + str(i)]) t.start() st = Setting() + st.shm = None + st.set_shm("test-setting-001") + time.sleep(1) self.assertEqual(thread_num, st.request_seq) + def test_multiprocessing(self): + if os.name == "nt": + return + + def test_mp(name): + st1 = Setting() + st1.shm = None + st1.set_shm("test-setting-002") + st1.set_container({'name': name, 'version': '0.1'}) + st1.incr_request_seq() + + process_num = 5 + for i in range(process_num): + p = multiprocessing.Process(target=test_mp, args=('test' + str(i),)) + p.start() + + st = Setting() + st.shm = None + st.set_shm("test-setting-002") + time.sleep(1) + self.assertEqual(process_num, st.request_seq) + if __name__ == '__main__': unittest.main() diff --git a/dongtai_agent_python/utils/lock.py b/dongtai_agent_python/utils/lock.py new file mode 100644 index 0000000..0d20ce9 --- /dev/null +++ b/dongtai_agent_python/utils/lock.py @@ -0,0 +1,14 @@ +from functools import wraps + + +def lock(_lock): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + _lock.acquire() + try: + return func(*args, **kwargs) + finally: + _lock.release() + return wrapper + return decorator diff --git a/dongtai_agent_python/utils/shm/__init__.py b/dongtai_agent_python/utils/shm/__init__.py new file mode 100644 index 0000000..0fdac92 --- /dev/null +++ b/dongtai_agent_python/utils/shm/__init__.py @@ -0,0 +1 @@ +from .shm import SharedMemoryDict diff --git a/dongtai_agent_python/utils/shm/shm.py b/dongtai_agent_python/utils/shm/shm.py new file mode 100644 index 0000000..dad82ba --- /dev/null +++ b/dongtai_agent_python/utils/shm/shm.py @@ -0,0 +1,148 @@ +import logging +import pickle +import sys +from contextlib import contextmanager +from multiprocessing import Lock + +from dongtai_agent_python.utils.lock import lock + +if sys.version_info[:3] < (3, 8): + from shared_memory.shared_memory import SharedMemory +else: + from multiprocessing.shared_memory import SharedMemory + +NULL_BYTE = b"\x00" + +logger = logging.getLogger(__name__) +_lock = Lock() +DEFAULT_OBJ = object() + + +class SharedMemoryDict: + def __init__(self, name, size=1024): + self.name = name + self.mem_block = self.get_or_create(size) + self.init_memory() + + @lock(_lock) + def get_or_create(self, size): + try: + return SharedMemory(name=self.name) + except FileNotFoundError: + return SharedMemory(name=self.name, create=True, size=size) + + def init_memory(self): + memory_is_empty = (bytes(self.mem_block.buf).split(NULL_BYTE, 1)[0] == b'') + if memory_is_empty: + self.save_memory({}) + + def close(self) -> None: + if not hasattr(self, 'mem_block'): + return + self.mem_block.close() + + def unlink(self) -> None: + if not hasattr(self, 'mem_block'): + return + self.mem_block.unlink() + + @lock + def clear(self) -> None: + self.save_memory({}) + + def popitem(self): + with self.modify_db() as db: + return db.popitem() + + def save_memory(self, db) -> None: + data = pickle.dumps(db) + try: + self.mem_block.buf[:len(data)] = data + except ValueError as exc: + logging.error("failed save to memory", exc_info=exc) + + def read_memory(self): + return pickle.loads(self.mem_block.buf.tobytes()) + + @contextmanager + @lock(_lock) + def modify_db(self): + db = self.read_memory() + yield db + self.save_memory(db) + + def __getitem__(self, key: str): + return self.read_memory()[key] + + def __setitem__(self, key: str, value) -> None: + with self.modify_db() as db: + db[key] = value + + def __len__(self) -> int: + return len(self.read_memory()) + + def __delitem__(self, key: str) -> None: + with self.modify_db() as db: + del db[key] + + def __iter__(self): + return iter(self.read_memory()) + + def __reversed__(self): + return reversed(self.read_memory()) + + def __del__(self) -> None: + self.close() + + def __contains__(self, key: str) -> bool: + return key in self.read_memory() + + def __eq__(self, other) -> bool: + return self.read_memory() == other + + def __ne__(self, other) -> bool: + return self.read_memory() != other + + if sys.version_info > (3, 8): + def __or__(self, other): + return self.read_memory() | other + + def __ror__(self, other): + return other | self.read_memory() + + def __ior__(self, other): + with self.modify_db() as db: + db |= other + return db + + def __str__(self): + return str(self.read_memory()) + + def __repr__(self): + return repr(self.read_memory()) + + def get(self, key: str, default=None): + return self.read_memory().get(key, default) + + def keys(self): + return self.read_memory().keys() + + def values(self): + return self.read_memory().values() + + def items(self): + return self.read_memory().items() + + def pop(self, key: str, default=DEFAULT_OBJ): + with self.modify_db() as db: + if default is DEFAULT_OBJ: + return db.pop(key) + return db.pop(key, default) + + def update(self, other=(), **kwargs): + with self.modify_db() as db: + db.update(other, **kwargs) + + def setdefault(self, key: str, default=None): + with self.modify_db() as db: + return db.setdefault(key, default) diff --git a/setup.cfg b/setup.cfg index 6d9b41a..e8d331e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,4 +29,5 @@ install_requires = psutil >= 5.8.0 requests >= 2.25.1 pip >= 19.2.3 + shared-memory38 >= 0.1.2; python_version < '3.8' ; regexploit >= 1.0.0