diff --git a/README.md b/README.md index a05c901..b5c516b 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ A powerful local AI workflow system with multi-model support and visual workflow ### Requirements - Python 3.9+ - Ollama (for local models) - [Download here](https://ollama.com/download) +- (Optional) SQL dependencies for RDS memory: `sqlalchemy` and `pymysql` for MySQL support ### Installation ```bash @@ -37,6 +38,11 @@ cd vertex pip install -e . ``` +```bash +# Optional SQL dependencies for RDS memory +pip install sqlalchemy pymysql +``` + ### Configuration ```bash # Quick setup - Initialize configuration diff --git a/README_EN.md b/README_EN.md index a05c901..b5c516b 100644 --- a/README_EN.md +++ b/README_EN.md @@ -25,6 +25,7 @@ A powerful local AI workflow system with multi-model support and visual workflow ### Requirements - Python 3.9+ - Ollama (for local models) - [Download here](https://ollama.com/download) +- (Optional) SQL dependencies for RDS memory: `sqlalchemy` and `pymysql` for MySQL support ### Installation ```bash @@ -37,6 +38,11 @@ cd vertex pip install -e . ``` +```bash +# Optional SQL dependencies for RDS memory +pip install sqlalchemy pymysql +``` + ### Configuration ```bash # Quick setup - Initialize configuration diff --git a/README_ZH.md b/README_ZH.md index 02a021d..3e8eb86 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -25,6 +25,7 @@ ### 环境要求 - Python 3.9+ - Ollama(本地模型)- [下载地址](https://ollama.com/download) +- (可选)RDS 内存后端需要的 SQL 依赖:`sqlalchemy`,若使用 MySQL 还需 `pymysql` ### 安装方式 @@ -58,6 +59,11 @@ cd vertex pip install -e . ``` +```bash +# 可选:安装 RDS 内存后端所需依赖 +pip install sqlalchemy pymysql +``` + ### 配置 ```bash # 快速设置 - 初始化配置 diff --git a/pyproject.toml b/pyproject.toml index 7c3deda..834a9c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,13 @@ cloud-vector = [ "dashvector>=1.0.19", ] +# Memory backends +memory = [ + "redis>=5.0.0", + "sqlalchemy>=2.0.0", + "pymysql>=1.1.0", +] + # 桌面端应用(可选) desktop = [ "pywebview>=5.4", @@ -113,6 +120,9 @@ all = [ "dashvector>=1.0.19", "pywebview>=5.4", "requests>=2.28.2", + "redis>=5.0.0", + "sqlalchemy>=2.0.0", + "pymysql>=1.1.0", ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index 1080461..9e07205 100644 --- a/requirements.txt +++ b/requirements.txt @@ -51,6 +51,11 @@ nest-asyncio>=1.6.0 # 进程管理 psutil>=5.9.0 +# 缓存和持久化存储 +redis>=5.0.0 +sqlalchemy>=2.0.0 +pymysql>=1.1.0 + # MCP (Model Context Protocol) 支持 aiohttp>=3.8.0 # 已包含在上面的网络依赖中 diff --git a/setup.py b/setup.py index 3d724de..42a2208 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,12 @@ "desktop": [ "pywebview>=5.4", ], + # 缓存和持久化存储 + "memory": [ + "redis>=5.0.0", + "sqlalchemy>=2.0.0", + "pymysql>=1.1.0", + ], # 完整功能(包含所有可选依赖) "all": [ "sentence-transformers>=2.2.0", @@ -84,6 +90,9 @@ "dashvector>=1.0.19", "pywebview>=5.4", "requests>=2.28.2", + "redis>=5.0.0", + "sqlalchemy>=2.0.0", + "pymysql>=1.1.0", ], }, entry_points={ diff --git a/vertex_flow/memory/__init__.py b/vertex_flow/memory/__init__.py index 54a98fa..629f673 100644 --- a/vertex_flow/memory/__init__.py +++ b/vertex_flow/memory/__init__.py @@ -11,7 +11,20 @@ from .factory import MemoryFactory, create_memory, create_memory_from_config from .file_store import FileMemory +from .hybrid_store import HybridMemory from .inmem_store import InnerMemory from .memory import Memory +from .redis_store import RedisMemory +from .rds_store import RDSMemory -__all__ = ["Memory", "InnerMemory", "FileMemory", "MemoryFactory", "create_memory", "create_memory_from_config"] +__all__ = [ + "Memory", + "InnerMemory", + "FileMemory", + "HybridMemory", + "RedisMemory", + "RDSMemory", + "MemoryFactory", + "create_memory", + "create_memory_from_config", +] diff --git a/vertex_flow/memory/factory.py b/vertex_flow/memory/factory.py index 73cc85e..1ebc1b6 100644 --- a/vertex_flow/memory/factory.py +++ b/vertex_flow/memory/factory.py @@ -3,8 +3,11 @@ from typing import Any, Dict, Optional from .file_store import FileMemory +from .hybrid_store import HybridMemory from .inmem_store import InnerMemory from .memory import Memory +from .redis_store import RedisMemory +from .rds_store import RDSMemory class MemoryFactory: @@ -16,6 +19,9 @@ class MemoryFactory: "memory": InnerMemory, # alias for backward compatibility "inmem": InnerMemory, # alias for backward compatibility "file": FileMemory, + "redis": RedisMemory, + "rds": RDSMemory, + "hybrid": HybridMemory, } @classmethod @@ -122,6 +128,17 @@ def get_default_config(cls, memory_type: str = "inner") -> Dict[str, Any]: return {"type": "inner", "hist_maxlen": 200, "cleanup_interval_sec": 300} elif memory_type == "file": return {"type": "file", "storage_dir": "./memory_data", "hist_maxlen": 200} + elif memory_type == "redis": + return {"type": "redis", "url": "redis://localhost:6379/0", "hist_maxlen": 200} + elif memory_type == "rds": + return {"type": "rds", "db_url": "sqlite:///:memory:", "hist_maxlen": 200} + elif memory_type == "hybrid": + return { + "type": "hybrid", + "redis_url": "redis://localhost:6379/0", + "db_url": "sqlite:///:memory:", + "hist_maxlen": 200, + } else: available_types = ", ".join(cls._memory_types.keys()) raise ValueError(f"Unsupported memory type: {memory_type}. " f"Available types: {available_types}") diff --git a/vertex_flow/memory/hybrid_store.py b/vertex_flow/memory/hybrid_store.py new file mode 100644 index 0000000..a58a363 --- /dev/null +++ b/vertex_flow/memory/hybrid_store.py @@ -0,0 +1,87 @@ +"""Hybrid memory store combining Redis for caching and RDS for persistence.""" + +from __future__ import annotations + +from typing import Any, Optional + +from .memory import Memory +from .rds_store import RDSMemory +from .redis_store import RedisMemory + + +class HybridMemory(Memory): + """Memory implementation using Redis as cache and RDS as persistent storage.""" + + def __init__( + self, + redis_url: str = "redis://localhost:6379/0", + db_url: str = "sqlite:///:memory:", + hist_maxlen: int = 200, + prefix: str = "vf:", + redis_client=None, + ) -> None: + self._redis = RedisMemory(url=redis_url, hist_maxlen=hist_maxlen, prefix=prefix, client=redis_client) + self._rds = RDSMemory(db_url=db_url, hist_maxlen=hist_maxlen) + self._hist_maxlen = hist_maxlen + + # Deduplication ----------------------------------------------------------------- + def seen(self, user_id: str, key: str, ttl_sec: int = 3600) -> bool: + result = self._rds.seen(user_id, key, ttl_sec) + self._redis.seen(user_id, key, ttl_sec) + return result + + # History ---------------------------------------------------------------------- + def append_history(self, user_id: str, role: str, mtype: str, content: dict, maxlen: int = 200) -> None: + self._rds.append_history(user_id, role, mtype, content, maxlen) + self._redis.append_history(user_id, role, mtype, content, maxlen) + + def recent_history(self, user_id: str, n: int = 20) -> list[dict]: + history = self._redis.recent_history(user_id, n) + if history: + return history + history = self._rds.recent_history(user_id, n) + for msg in reversed(history): + self._redis.append_history(user_id, msg["role"], msg["type"], msg["content"], self._hist_maxlen) + return history + + # Context ---------------------------------------------------------------------- + def ctx_set(self, user_id: str, key: str, value: Any, ttl_sec: Optional[int] = None) -> None: + self._rds.ctx_set(user_id, key, value, ttl_sec) + self._redis.ctx_set(user_id, key, value, ttl_sec) + + def ctx_get(self, user_id: str, key: str) -> Optional[Any]: + value = self._redis.ctx_get(user_id, key) + if value is not None: + return value + value = self._rds.ctx_get(user_id, key) + if value is not None: + self._redis.ctx_set(user_id, key, value) + return value + + def ctx_del(self, user_id: str, key: str) -> None: + self._rds.ctx_del(user_id, key) + self._redis.ctx_del(user_id, key) + + # Ephemeral -------------------------------------------------------------------- + def set_ephemeral(self, user_id: str, key: str, value: Any, ttl_sec: int = 1800) -> None: + self._rds.set_ephemeral(user_id, key, value, ttl_sec) + self._redis.set_ephemeral(user_id, key, value, ttl_sec) + + def get_ephemeral(self, user_id: str, key: str) -> Optional[Any]: + value = self._redis.get_ephemeral(user_id, key) + if value is not None: + return value + return self._rds.get_ephemeral(user_id, key) + + def del_ephemeral(self, user_id: str, key: str) -> None: + self._rds.del_ephemeral(user_id, key) + self._redis.del_ephemeral(user_id, key) + + # Rate limiting ---------------------------------------------------------------- + def incr_rate(self, user_id: str, bucket: str, ttl_sec: int = 60) -> int: + count = self._rds.incr_rate(user_id, bucket, ttl_sec) + redis_key = self._redis._rate_key(user_id, bucket) + self._redis._client.set(redis_key, str(count)) + if ttl_sec > 0: + self._redis._client.expire(redis_key, ttl_sec) + return count diff --git a/vertex_flow/memory/rds_store.py b/vertex_flow/memory/rds_store.py new file mode 100644 index 0000000..41bdb5d --- /dev/null +++ b/vertex_flow/memory/rds_store.py @@ -0,0 +1,242 @@ +"""Relational database backed memory store using SQLAlchemy.""" + +from __future__ import annotations + +import json +import threading +import time +from typing import Any, Optional +from urllib.parse import urlparse + +from .memory import Memory + +try: # pragma: no cover - optional dependency + import sqlalchemy as sa +except Exception: # pragma: no cover + sa = None + + +class RDSMemory(Memory): + """Relational database backed memory store supporting SQLite and MySQL.""" + + def __init__( + self, + db_url: Optional[str] = None, + db_path: Optional[str] = None, + hist_maxlen: int = 200, + ) -> None: + if db_url is None: + if db_path is not None: + db_url = f"sqlite:///{db_path}" + else: + db_url = "sqlite:///:memory:" + self._db_url = db_url + self._hist_maxlen = hist_maxlen + self._lock = threading.RLock() + + if sa is None: # pragma: no cover - handled in tests + raise RuntimeError("sqlalchemy is required for RDSMemory") + + parsed = urlparse(db_url) + scheme = (parsed.scheme or "sqlite").lower() + if scheme.startswith("mysql"): + try: # pragma: no cover - optional dependency + import pymysql # noqa: F401 + except Exception as exc: + raise RuntimeError("pymysql is required for MySQL support") from exc + if "+" not in scheme: + db_url = "mysql+pymysql" + db_url[len("mysql") :] + elif not scheme.startswith("sqlite"): + raise ValueError(f"Unsupported RDS scheme: {scheme}") + + self._engine = sa.create_engine(db_url, future=True) + + self._meta = sa.MetaData() + self._dedup = sa.Table( + "dedup", + self._meta, + sa.Column("user_id", sa.String(255), primary_key=True), + sa.Column("key", sa.String(255), primary_key=True), + sa.Column("expires_at", sa.Float, nullable=True), + ) + self._history = sa.Table( + "history", + self._meta, + sa.Column("id", sa.BigInteger, primary_key=True, autoincrement=True), + sa.Column("user_id", sa.String(255)), + sa.Column("message", sa.Text), + sa.Column("timestamp", sa.Float), + ) + self._ctx = sa.Table( + "ctx", + self._meta, + sa.Column("user_id", sa.String(255), primary_key=True), + sa.Column("key", sa.String(255), primary_key=True), + sa.Column("value", sa.Text), + sa.Column("expires_at", sa.Float, nullable=True), + ) + self._ephemeral = sa.Table( + "ephemeral", + self._meta, + sa.Column("user_id", sa.String(255), primary_key=True), + sa.Column("key", sa.String(255), primary_key=True), + sa.Column("value", sa.Text), + sa.Column("expires_at", sa.Float, nullable=True), + ) + self._rate = sa.Table( + "rate", + self._meta, + sa.Column("user_id", sa.String(255), primary_key=True), + sa.Column("bucket", sa.String(255), primary_key=True), + sa.Column("value", sa.Integer), + sa.Column("expires_at", sa.Float, nullable=True), + ) + self._meta.create_all(self._engine) + + def _is_expired(self, expires_at: Optional[float]) -> bool: + return expires_at is not None and time.time() > expires_at + + # Deduplication ----------------------------------------------------------------- + def seen(self, user_id: str, key: str, ttl_sec: int = 3600) -> bool: + with self._lock, self._engine.begin() as conn: + stmt = sa.select(self._dedup.c.expires_at).where(self._dedup.c.user_id == user_id, self._dedup.c.key == key) + row = conn.execute(stmt).fetchone() + if row and not self._is_expired(row.expires_at): + return True + expires_at = time.time() + ttl_sec if ttl_sec > 0 else None + conn.execute(sa.delete(self._dedup).where(self._dedup.c.user_id == user_id, self._dedup.c.key == key)) + conn.execute(self._dedup.insert().values(user_id=user_id, key=key, expires_at=expires_at)) + return False + + # History ---------------------------------------------------------------------- + def append_history(self, user_id: str, role: str, mtype: str, content: dict, maxlen: int = 200) -> None: + with self._lock, self._engine.begin() as conn: + ts = time.time() + message = json.dumps({"role": role, "type": mtype, "content": content, "timestamp": ts}) + conn.execute(self._history.insert().values(user_id=user_id, message=message, timestamp=ts)) + sub = ( + sa.select(self._history.c.id) + .where(self._history.c.user_id == user_id) + .order_by(self._history.c.timestamp.desc()) + .limit(maxlen) + .subquery() + ) + conn.execute( + sa.delete(self._history).where( + self._history.c.user_id == user_id, + self._history.c.id.notin_(sa.select(sub.c.id)), + ) + ) + + def recent_history(self, user_id: str, n: int = 20) -> list[dict]: + with self._lock, self._engine.begin() as conn: + stmt = ( + sa.select(self._history.c.message) + .where(self._history.c.user_id == user_id) + .order_by(self._history.c.timestamp.desc()) + .limit(n) + ) + rows = conn.execute(stmt).fetchall() + return [json.loads(row.message) for row in rows] + + # Context ---------------------------------------------------------------------- + def ctx_set(self, user_id: str, key: str, value: Any, ttl_sec: Optional[int] = None) -> None: + with self._lock, self._engine.begin() as conn: + expires_at = time.time() + ttl_sec if ttl_sec and ttl_sec > 0 else None + conn.execute(sa.delete(self._ctx).where(self._ctx.c.user_id == user_id, self._ctx.c.key == key)) + conn.execute( + self._ctx.insert().values( + user_id=user_id, + key=key, + value=json.dumps(value, ensure_ascii=False), + expires_at=expires_at, + ) + ) + + def ctx_get(self, user_id: str, key: str) -> Optional[Any]: + with self._lock, self._engine.begin() as conn: + stmt = sa.select(self._ctx.c.value, self._ctx.c.expires_at).where( + self._ctx.c.user_id == user_id, self._ctx.c.key == key + ) + row = conn.execute(stmt).fetchone() + if not row: + return None + if self._is_expired(row.expires_at): + conn.execute(sa.delete(self._ctx).where(self._ctx.c.user_id == user_id, self._ctx.c.key == key)) + return None + return json.loads(row.value) + + def ctx_del(self, user_id: str, key: str) -> None: + with self._lock, self._engine.begin() as conn: + conn.execute(sa.delete(self._ctx).where(self._ctx.c.user_id == user_id, self._ctx.c.key == key)) + + # Ephemeral -------------------------------------------------------------------- + def set_ephemeral(self, user_id: str, key: str, value: Any, ttl_sec: int = 1800) -> None: + with self._lock, self._engine.begin() as conn: + expires_at = time.time() + ttl_sec if ttl_sec > 0 else None + conn.execute( + sa.delete(self._ephemeral).where( + self._ephemeral.c.user_id == user_id, + self._ephemeral.c.key == key, + ) + ) + conn.execute( + self._ephemeral.insert().values( + user_id=user_id, + key=key, + value=json.dumps(value, ensure_ascii=False), + expires_at=expires_at, + ) + ) + + def get_ephemeral(self, user_id: str, key: str) -> Optional[Any]: + with self._lock, self._engine.begin() as conn: + stmt = sa.select(self._ephemeral.c.value, self._ephemeral.c.expires_at).where( + self._ephemeral.c.user_id == user_id, + self._ephemeral.c.key == key, + ) + row = conn.execute(stmt).fetchone() + if not row: + return None + if self._is_expired(row.expires_at): + conn.execute( + sa.delete(self._ephemeral).where( + self._ephemeral.c.user_id == user_id, + self._ephemeral.c.key == key, + ) + ) + return None + return json.loads(row.value) + + def del_ephemeral(self, user_id: str, key: str) -> None: + with self._lock, self._engine.begin() as conn: + conn.execute( + sa.delete(self._ephemeral).where( + self._ephemeral.c.user_id == user_id, + self._ephemeral.c.key == key, + ) + ) + + # Rate limiting ---------------------------------------------------------------- + def incr_rate(self, user_id: str, bucket: str, ttl_sec: int = 60) -> int: + with self._lock, self._engine.begin() as conn: + stmt = sa.select(self._rate.c.value, self._rate.c.expires_at).where( + self._rate.c.user_id == user_id, self._rate.c.bucket == bucket + ) + row = conn.execute(stmt).fetchone() + now = time.time() + expires_at = now + ttl_sec if ttl_sec > 0 else None + if not row or self._is_expired(row.expires_at): + value = 1 + else: + value = int(row.value) + 1 + conn.execute(sa.delete(self._rate).where(self._rate.c.user_id == user_id, self._rate.c.bucket == bucket)) + conn.execute( + self._rate.insert().values( + user_id=user_id, + bucket=bucket, + value=value, + expires_at=expires_at, + ) + ) + return int(value) diff --git a/vertex_flow/memory/redis_store.py b/vertex_flow/memory/redis_store.py new file mode 100644 index 0000000..3f52b76 --- /dev/null +++ b/vertex_flow/memory/redis_store.py @@ -0,0 +1,115 @@ +"""Redis-based implementation of Memory interface.""" + +from __future__ import annotations + +import json +from typing import Any, Optional + +try: # pragma: no cover - optional dependency + import redis +except Exception: # pragma: no cover + redis = None + +from .memory import Memory + + +class RedisMemory(Memory): + """Redis-based memory store. + + Args: + url: Redis connection URL. Defaults to ``redis://localhost:6379/0``. + hist_maxlen: Default maximum history length. + prefix: Key prefix for namespacing. + client: Optional pre-initialized ``redis.Redis`` client. + """ + + def __init__( + self, + url: str = "redis://localhost:6379/0", + hist_maxlen: int = 200, + prefix: str = "vf:", + client: Optional[redis.Redis] = None, + ) -> None: + if client is not None: + self._client = client + else: + if redis is None: + raise ImportError("redis package is required") + self._client = redis.Redis.from_url(url, decode_responses=True) + self._hist_maxlen = hist_maxlen + self._prefix = prefix + + # Key helpers ----------------------------------------------------------------- + def _hist_key(self, user_id: str) -> str: + return f"{self._prefix}hist:{user_id}" + + def _ctx_key(self, user_id: str, key: str) -> str: + return f"{self._prefix}ctx:{user_id}:{key}" + + def _ephemeral_key(self, user_id: str, key: str) -> str: + return f"{self._prefix}ephemeral:{user_id}:{key}" + + def _dedup_key(self, user_id: str, key: str) -> str: + return f"{self._prefix}dedup:{user_id}:{key}" + + def _rate_key(self, user_id: str, bucket: str) -> str: + return f"{self._prefix}rate:{user_id}:{bucket}" + + # Memory API ------------------------------------------------------------------- + def seen(self, user_id: str, key: str, ttl_sec: int = 3600) -> bool: + redis_key = self._dedup_key(user_id, key) + result = self._client.set(redis_key, "1", nx=True, ex=ttl_sec if ttl_sec > 0 else None) + return result is None + + def append_history(self, user_id: str, role: str, mtype: str, content: dict, maxlen: int = 200) -> None: + message = json.dumps({"role": role, "type": mtype, "content": content}) + key = self._hist_key(user_id) + pipe = self._client.pipeline() + pipe.lpush(key, message) + pipe.ltrim(key, 0, maxlen - 1) + pipe.execute() + + def recent_history(self, user_id: str, n: int = 20) -> list[dict]: + key = self._hist_key(user_id) + messages = self._client.lrange(key, 0, n - 1) + return [json.loads(m) for m in messages] + + def ctx_set(self, user_id: str, key: str, value: Any, ttl_sec: Optional[int] = None) -> None: + redis_key = self._ctx_key(user_id, key) + value_str = json.dumps(value, ensure_ascii=False) + if ttl_sec and ttl_sec > 0: + self._client.set(redis_key, value_str, ex=ttl_sec) + else: + self._client.set(redis_key, value_str) + + def ctx_get(self, user_id: str, key: str) -> Optional[Any]: + value = self._client.get(self._ctx_key(user_id, key)) + if value is None: + return None + return json.loads(value) + + def ctx_del(self, user_id: str, key: str) -> None: + self._client.delete(self._ctx_key(user_id, key)) + + def set_ephemeral(self, user_id: str, key: str, value: Any, ttl_sec: int = 1800) -> None: + redis_key = self._ephemeral_key(user_id, key) + value_str = json.dumps(value, ensure_ascii=False) + self._client.set(redis_key, value_str, ex=ttl_sec if ttl_sec > 0 else None) + + def get_ephemeral(self, user_id: str, key: str) -> Optional[Any]: + value = self._client.get(self._ephemeral_key(user_id, key)) + if value is None: + return None + return json.loads(value) + + def del_ephemeral(self, user_id: str, key: str) -> None: + self._client.delete(self._ephemeral_key(user_id, key)) + + def incr_rate(self, user_id: str, bucket: str, ttl_sec: int = 60) -> int: + key = self._rate_key(user_id, bucket) + pipe = self._client.pipeline() + pipe.incr(key) + if ttl_sec > 0: + pipe.expire(key, ttl_sec) + count, _ = pipe.execute() + return int(count) diff --git a/vertex_flow/tests/test_memory_factory.py b/vertex_flow/tests/test_memory_factory.py index aa9de81..774cbec 100644 --- a/vertex_flow/tests/test_memory_factory.py +++ b/vertex_flow/tests/test_memory_factory.py @@ -1,12 +1,25 @@ """Tests for Memory Factory.""" +import importlib import shutil import tempfile from pathlib import Path import pytest -from vertex_flow.memory import FileMemory, InnerMemory, Memory, MemoryFactory, create_memory, create_memory_from_config +from vertex_flow.memory import ( + FileMemory, + HybridMemory, + InnerMemory, + Memory, + MemoryFactory, + RDSMemory, + RedisMemory, + create_memory, + create_memory_from_config, +) + +has_sqlalchemy = importlib.util.find_spec("sqlalchemy") is not None class TestMemoryFactory: @@ -36,6 +49,29 @@ def test_create_file_memory(self): assert isinstance(memory, FileMemory) assert memory._storage_dir == Path(temp_dir) + def test_create_redis_memory(self): + """Test creating RedisMemory through factory.""" + + class DummyRedis: + pass + + memory = MemoryFactory.create_memory("redis", client=DummyRedis()) + assert isinstance(memory, RedisMemory) + + @pytest.mark.skipif(not has_sqlalchemy, reason="sqlalchemy required") + def test_create_rds_memory(self): + """Test creating RDSMemory through factory.""" + memory = MemoryFactory.create_memory("rds", db_url="sqlite:///:memory:") + assert isinstance(memory, RDSMemory) + + @pytest.mark.skipif(not has_sqlalchemy, reason="sqlalchemy required") + def test_create_hybrid_memory(self): + class DummyRedis: + pass + + memory = MemoryFactory.create_memory("hybrid", redis_client=DummyRedis(), db_url="sqlite:///:memory:") + assert isinstance(memory, HybridMemory) + def test_create_memory_invalid_type(self): """Test creating memory with invalid type.""" with pytest.raises(ValueError, match="Unsupported memory type: invalid"): @@ -97,6 +133,9 @@ def test_get_available_types(self): assert "memory" in types assert "inmem" in types assert "file" in types + assert "redis" in types + assert "rds" in types + assert "hybrid" in types assert isinstance(types, list) def test_get_default_config_inner(self): @@ -111,6 +150,30 @@ def test_get_default_config_file(self): expected = {"type": "file", "storage_dir": "./memory_data", "hist_maxlen": 200} assert config == expected + def test_get_default_config_redis(self): + """Test getting default config for redis memory.""" + config = MemoryFactory.get_default_config("redis") + expected = {"type": "redis", "url": "redis://localhost:6379/0", "hist_maxlen": 200} + assert config == expected + + @pytest.mark.skipif(not has_sqlalchemy, reason="sqlalchemy required") + def test_get_default_config_rds(self): + """Test getting default config for rds memory.""" + config = MemoryFactory.get_default_config("rds") + expected = {"type": "rds", "db_url": "sqlite:///:memory:", "hist_maxlen": 200} + assert config == expected + + @pytest.mark.skipif(not has_sqlalchemy, reason="sqlalchemy required") + def test_get_default_config_hybrid(self): + config = MemoryFactory.get_default_config("hybrid") + expected = { + "type": "hybrid", + "redis_url": "redis://localhost:6379/0", + "db_url": "sqlite:///:memory:", + "hist_maxlen": 200, + } + assert config == expected + def test_get_default_config_invalid_type(self): """Test getting default config for invalid type.""" with pytest.raises(ValueError, match="Unsupported memory type: invalid"): @@ -139,6 +202,21 @@ def test_memory_interface_compliance(self): file_memory = MemoryFactory.create_memory("file", storage_dir=temp_dir) assert isinstance(file_memory, Memory) + class DummyRedis: + pass + + redis_memory = MemoryFactory.create_memory("redis", client=DummyRedis()) + assert isinstance(redis_memory, Memory) + + if has_sqlalchemy: + rds_memory = MemoryFactory.create_memory("rds", db_url="sqlite:///:memory:") + assert isinstance(rds_memory, Memory) + + hybrid_memory = MemoryFactory.create_memory( + "hybrid", redis_client=DummyRedis(), db_url="sqlite:///:memory:" + ) + assert isinstance(hybrid_memory, Memory) + def test_factory_integration(self): """Test end-to-end factory usage.""" # Test with inner memory diff --git a/vertex_flow/tests/test_memory_hybrid.py b/vertex_flow/tests/test_memory_hybrid.py new file mode 100644 index 0000000..e6cdbff --- /dev/null +++ b/vertex_flow/tests/test_memory_hybrid.py @@ -0,0 +1,138 @@ +"""Tests for HybridMemory combining Redis and RDS.""" + +import time + +import pytest + +pytest.importorskip("sqlalchemy") + +from vertex_flow.memory import HybridMemory + + +class DummyPipeline: + def __init__(self, client): + self._client = client + self._commands = [] + + def lpush(self, *args): + self._commands.append(("lpush", args)) + return self + + def ltrim(self, *args): + self._commands.append(("ltrim", args)) + return self + + def incr(self, *args): + self._commands.append(("incr", args)) + return self + + def expire(self, *args): + self._commands.append(("expire", args)) + return self + + def execute(self): + results = [] + for cmd, args in self._commands: + results.append(getattr(self._client, cmd)(*args)) + self._commands.clear() + return results + + +class DummyRedis: + def __init__(self): + self._store = {} + self._lists = {} + + def _check_expired(self, key): + if key in self._store: + value, exp = self._store[key] + if exp is not None and time.time() > exp: + del self._store[key] + + def set(self, key, value, nx=False, ex=None): + self._check_expired(key) + if nx and key in self._store: + return None + expires_at = time.time() + ex if ex else None + self._store[key] = (value, expires_at) + return True + + def get(self, key): + self._check_expired(key) + if key not in self._store: + return None + return self._store[key][0] + + def delete(self, key): + self._store.pop(key, None) + + def lpush(self, key, value): + self._lists.setdefault(key, []) + self._lists[key].insert(0, value) + + def ltrim(self, key, start, end): + self._lists.setdefault(key, []) + self._lists[key] = self._lists[key][start : end + 1] + + def lrange(self, key, start, end): + lst = self._lists.get(key, []) + if end == -1: + end = len(lst) - 1 + return lst[start : end + 1] + + def pipeline(self): + return DummyPipeline(self) + + def incr(self, key): + self._check_expired(key) + value = int(self._store.get(key, ("0", None))[0]) + 1 + _, exp = self._store.get(key, (None, None)) + self._store[key] = (str(value), exp) + return value + + def expire(self, key, ttl): + if key in self._store: + value, _ = self._store[key] + self._store[key] = (value, time.time() + ttl) + return True + return False + + +class TestHybridMemory: + def setup_method(self): + self.redis = DummyRedis() + self.memory = HybridMemory(redis_client=self.redis, db_url="sqlite:///:memory:", hist_maxlen=5) + + def test_seen_deduplication_eventual(self): + assert self.memory.seen("u", "k", ttl_sec=1) is False + # simulate redis loss + self.redis._store.clear() + assert self.memory.seen("u", "k", ttl_sec=1) is True + + def test_history_fallback(self): + uid = "u" + for i in range(3): + self.memory.append_history(uid, "user", "text", {"text": str(i)}, maxlen=5) + # drop redis history + self.redis._lists.clear() + history = self.memory.recent_history(uid, n=5) + assert len(history) == 3 + # redis should be repopulated + assert self.redis.lrange(self.memory._redis._hist_key(uid), 0, -1) + + def test_ctx_eventual(self): + self.memory.ctx_set("u", "k", {"v": 1}) + self.redis._store.clear() + assert self.memory.ctx_get("u", "k") == {"v": 1} + + def test_ephemeral_operations(self): + self.memory.set_ephemeral("u", "k", 1, ttl_sec=1) + self.redis._store.clear() + assert self.memory.get_ephemeral("u", "k") == 1 + time.sleep(1.1) + assert self.memory.get_ephemeral("u", "k") is None + + def test_rate_counter_eventual(self): + assert self.memory.incr_rate("u", "b", ttl_sec=10) == 1 + self.redis._store.clear() + assert self.memory.incr_rate("u", "b", ttl_sec=10) == 2 diff --git a/vertex_flow/tests/test_memory_rds.py b/vertex_flow/tests/test_memory_rds.py new file mode 100644 index 0000000..2bcb671 --- /dev/null +++ b/vertex_flow/tests/test_memory_rds.py @@ -0,0 +1,57 @@ +"""Tests for RDSMemory implementation.""" + +import time + +import pytest + +pytest.importorskip("sqlalchemy") + +from vertex_flow.memory import RDSMemory + + +class TestRDSMemory: + """Basic test cases for RDSMemory.""" + + def setup_method(self): + self.memory = RDSMemory(db_url="sqlite:///:memory:", hist_maxlen=5) + + def test_seen_deduplication(self): + user_id = "u" + key = "k" + assert self.memory.seen(user_id, key, ttl_sec=1) is False + assert self.memory.seen(user_id, key, ttl_sec=1) is True + time.sleep(1.1) + assert self.memory.seen(user_id, key, ttl_sec=1) is False + + def test_append_history_maxlen(self): + uid = "u" + for i in range(10): + self.memory.append_history(uid, "user", "text", {"text": str(i)}, maxlen=5) + history = self.memory.recent_history(uid, n=10) + assert len(history) == 5 + assert history[0]["content"]["text"] == "9" + assert history[4]["content"]["text"] == "5" + + def test_ctx_operations(self): + self.memory.ctx_set("u", "k", {"v": 1}) + assert self.memory.ctx_get("u", "k") == {"v": 1} + self.memory.ctx_del("u", "k") + assert self.memory.ctx_get("u", "k") is None + + def test_ephemeral_operations(self): + self.memory.set_ephemeral("u", "k", 1, ttl_sec=1) + assert self.memory.get_ephemeral("u", "k") == 1 + time.sleep(1.1) + assert self.memory.get_ephemeral("u", "k") is None + + def test_incr_rate_counter(self): + uid = "u" + bucket = "b" + assert self.memory.incr_rate(uid, bucket, ttl_sec=1) == 1 + assert self.memory.incr_rate(uid, bucket, ttl_sec=1) == 2 + time.sleep(1.1) + assert self.memory.incr_rate(uid, bucket, ttl_sec=1) == 1 + + def test_mysql_driver_required(self): + with pytest.raises(RuntimeError): + RDSMemory(db_url="mysql://user:pass@localhost/test") diff --git a/vertex_flow/tests/test_memory_redis.py b/vertex_flow/tests/test_memory_redis.py new file mode 100644 index 0000000..cbac1aa --- /dev/null +++ b/vertex_flow/tests/test_memory_redis.py @@ -0,0 +1,139 @@ +"""Tests for RedisMemory implementation.""" + +import time + +from vertex_flow.memory import RedisMemory + + +class DummyPipeline: + def __init__(self, client): + self._client = client + self._commands = [] + + def lpush(self, *args): + self._commands.append(("lpush", args)) + return self + + def ltrim(self, *args): + self._commands.append(("ltrim", args)) + return self + + def incr(self, *args): + self._commands.append(("incr", args)) + return self + + def expire(self, *args): + self._commands.append(("expire", args)) + return self + + def execute(self): + results = [] + for cmd, args in self._commands: + results.append(getattr(self._client, cmd)(*args)) + self._commands.clear() + return results + + +class DummyRedis: + def __init__(self): + self._store = {} + self._lists = {} + + def _check_expired(self, key): + if key in self._store: + value, exp = self._store[key] + if exp is not None and time.time() > exp: + del self._store[key] + + def set(self, key, value, nx=False, ex=None): + self._check_expired(key) + if nx and key in self._store: + return None + expires_at = time.time() + ex if ex else None + self._store[key] = (value, expires_at) + return True + + def get(self, key): + self._check_expired(key) + if key not in self._store: + return None + return self._store[key][0] + + def delete(self, key): + self._store.pop(key, None) + + def lpush(self, key, value): + self._lists.setdefault(key, []) + self._lists[key].insert(0, value) + + def ltrim(self, key, start, end): + self._lists.setdefault(key, []) + self._lists[key] = self._lists[key][start : end + 1] + + def lrange(self, key, start, end): + lst = self._lists.get(key, []) + if end == -1: + end = len(lst) - 1 + return lst[start : end + 1] + + def pipeline(self): + return DummyPipeline(self) + + def incr(self, key): + self._check_expired(key) + value = int(self._store.get(key, ("0", None))[0]) + 1 + _, exp = self._store.get(key, (None, None)) + self._store[key] = (str(value), exp) + return value + + def expire(self, key, ttl): + if key in self._store: + value, _ = self._store[key] + self._store[key] = (value, time.time() + ttl) + return True + return False + + +class TestRedisMemory: + """Basic test cases for RedisMemory.""" + + def setup_method(self): + self.redis = DummyRedis() + self.memory = RedisMemory(client=self.redis, hist_maxlen=5) + + def test_seen_deduplication(self): + user_id = "user1" + key = "k" + assert self.memory.seen(user_id, key, ttl_sec=1) is False + assert self.memory.seen(user_id, key, ttl_sec=1) is True + time.sleep(1.1) + assert self.memory.seen(user_id, key, ttl_sec=1) is False + + def test_append_history_maxlen(self): + user_id = "user1" + for i in range(10): + self.memory.append_history(user_id, "user", "text", {"text": str(i)}, maxlen=5) + history = self.memory.recent_history(user_id, n=10) + assert len(history) == 5 + assert history[0]["content"]["text"] == "9" + assert history[4]["content"]["text"] == "5" + + def test_ctx_operations(self): + self.memory.ctx_set("u", "k", {"v": 1}) + assert self.memory.ctx_get("u", "k") == {"v": 1} + self.memory.ctx_del("u", "k") + assert self.memory.ctx_get("u", "k") is None + + def test_ephemeral_operations(self): + self.memory.set_ephemeral("u", "k", 1, ttl_sec=1) + assert self.memory.get_ephemeral("u", "k") == 1 + time.sleep(1.1) + assert self.memory.get_ephemeral("u", "k") is None + + def test_incr_rate_counter(self): + user_id = "u" + bucket = "b" + assert self.memory.incr_rate(user_id, bucket, ttl_sec=1) == 1 + assert self.memory.incr_rate(user_id, bucket, ttl_sec=1) == 2 + time.sleep(1.1) + assert self.memory.incr_rate(user_id, bucket, ttl_sec=1) == 1