diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..a9ba69c --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,238 @@ +# SmartFeed Architecture (medium-brief) + +## 1) What SmartFeed does + +SmartFeed builds one paginated feed from multiple client-provided sources (“subfeeds”) using a declarative tree config: + +- **Leaf**: `SubFeed` (calls one client method) +- **Mergers**: compose children (`append`, `distribute`, `positional`, `percentage`, `percentage_gradient`, `view_session`) +- **Wrapper**: `MergerDeduplication` (changes execution semantics around one child) + +Core runtime: + +- parse config -> create request `ExecutionContext` -> run tree via shared `Executor` -> return `FeedResult` + `next_page`. + + +## 2) Public surfaces and core data types + +### Public entrypoint + +- `FeedManager(config, methods_dict, redis_client=None)` + - `get_data(user_id, limit, next_page, **params) -> FeedResult` + +`methods_dict` maps config `method_name` strings to host-app callables. + +### Config schema surface + +`smartfeed.schemas` keeps stable imports for: + +- `FeedConfig`: top-level model (`version`, `feed`) +- `FeedTypes`: discriminated union by `type` + +### Cursor / pagination models + +- `FeedResultNextPageInside`: one node cursor (`page`, `after`) +- `FeedResultNextPage`: full-tree cursor map (`data: {node_id -> FeedResultNextPageInside}`) + +### Result models + +- `FeedResultClient`: required return type of client subfeed methods +- `FeedResult`: normalized return type of any SmartFeed node + + +## 3) Node interface contract + +All nodes inherit `BaseFeedConfigModel` and are executed through: + +- `get_data(methods_dict, user_id, limit, next_page, redis_client=None, ctx=None, **params) -> FeedResult` + +Important notes: + +- If a node implements `build_plan(...)`, executor uses the plan path. +- Base `get_data(...)` delegates back to executor and expects `build_plan(...)` to exist. +- Every node has `dedup_priority: int` (used by dedup arbitration/refill ordering). + + +## 4) ExecutionContext + +`ExecutionContext` is per-request state propagated through the tree: + +- `methods_dict`, `user_id`, `redis_client` +- `executor` (lazy via `ensure_executor()`) +- optional policy/settings: + - `dedup`: `DeduplicationPolicy` when dedup wrapper is active + - `refill_settings`: `RefillExecutionSettings(overfetch_factor, max_refill_loops)` + +Responsibilities: + +- centralize shared plumbing (executor + redis client) +- keep execution policies out of user params + + +## 5) Executor (runtime engine) + +Primary entry: + +- `Executor.run(node, ctx, limit, next_page, **params) -> FeedResult` + +Execution strategy: + +1. **Plan-first** + - `build_plan(...)` -> execute returned `Plan` + - otherwise call node `get_data(...)` +2. **Centralized concurrency** + - child runs use executor-managed `asyncio.gather(...)` +3. **Dedup/refill hooks** + - for non-slot nodes with `ctx.dedup`, run `DedupRuntime.run_node_with_dedup_refill(...)` + - for `SlotsPlan`, dedup/refill is handled inside slot execution + +`SlotsPlan` execution highlights: + +1. collect unique owners + demand per owner +2. fetch owners concurrently (with optional `owner_fetch_limits` overrides) +3. merge only changed cursor keys (`CursorMap.merge_delta`) +4. apply: + - dedup arbitration + refill (`apply_slots_plan_dedup`) when `ctx.dedup` exists + - refill-only deficits (`apply_slots_plan_refill`) when only `ctx.refill_settings` exists +5. 
consume slot schedule and call `assemble(...)` + +When dedup is active for a slots plan, owners are executed with `dedup=None` in owner context so global arbitration stays centralized. + + +## 6) Plans: declarative execution + +Plans separate “what to run” from “how to run it”. + +- `CallablePlan(fn)` + - node-provided async function with custom flow, still executed by executor + +- `SlotsPlan(ctx, limit, next_page, params, slots, assemble, owner_fetch_limits=None)` + - `slots`: ordered `SlotSpec(owner, max_count)` schedule + - `assemble(output, merged_next_page, owner_results)`: builds final `FeedResult` + + +## 7) Mergers and leaf responsibilities + +### SubFeed (leaf) + +- derives its local cursor from `next_page.data[subfeed_id]` (defaults page=1/after=None) +- calls `methods_dict[method_name]` +- passes only params present in method signature + `subfeed_params` +- async methods are awaited; sync methods run via `asyncio.to_thread(...)` +- `raise_error=False` converts method failure into empty `FeedResultClient` +- optional `shuffle` then normalizes to `FeedResult` + +### Slot-based mergers + +These build `SlotsPlan`: + +- `MergerAppend`: concatenation (optional shuffle) +- `MergerAppendDistribute` (`type="merger_distribute"`): append then redistribute by `distribution_key` +- `MergerPositional`: page-local slot ownership for `positional` vs `default`, keeps its own merger cursor +- `MergerPercentage`: integer allocation by percentages; when total is exactly 100, remainder is distributed to avoid underfill +- `MergerPercentageGradient`: two-owner percentage curve across the page, then advances merger page cursor + +### MergerViewSession (Redis-backed session cache) + +Goal: cache a session-sized list and serve slices. + +Flow: + +1. build cache key: `{merger_id}_{user_id}` + optional suffix from `custom_view_session_key` +2. check Redis `exists`; if no cache or no merger cursor in request -> regenerate session +3. on hit, `get`; if Redis returns `None` unexpectedly, regenerate +4. on generation: execute child once for `session_size`, optional dedup, store JSON with TTL +5. return page slice and increment merger cursor page +6. optional `shuffle` is applied to returned page slice (cache payload is not reshuffled) + +### MergerDeduplication (single-child wrapper) + +Goal: deduplicate while keeping child mix/slot semantics. 
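+
+A minimal end-to-end sketch of the wrapper in use (the `posts` method, its
+payload, and the user id are illustrative assumptions, not part of the package):
+
+```python
+import asyncio
+
+from smartfeed.manager import FeedManager
+from smartfeed.schemas import FeedResultClient, FeedResultNextPage, FeedResultNextPageInside
+
+
+async def posts(user_id, limit: int, next_page: FeedResultNextPageInside) -> FeedResultClient:
+    # Hypothetical source; each item carries the "id" field used as dedup_key.
+    items = [{"id": f"post_{i}"} for i in range(limit)]
+    next_page.page += 1
+    return FeedResultClient(data=items, next_page=next_page, has_next_page=True)
+
+
+config = {
+    "version": "1",
+    "feed": {
+        "merger_id": "dedup_main",
+        "type": "merger_deduplication",
+        "dedup_key": "id",
+        "state_backend": "cursor",
+        "data": {"subfeed_id": "sf_posts", "type": "subfeed", "method_name": "posts"},
+    },
+}
+
+manager = FeedManager(config=config, methods_dict={"posts": posts})
+result = asyncio.run(manager.get_data(user_id="u1", limit=10, next_page=FeedResultNextPage(data={})))
+```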
+ +Key behavior: + +- fresh session when merger cursor is absent or `page <= 0` + - reset descendant cursors + - for Redis backend, reset Redis seen-state key +- seen-state backend: + - `cursor`: encoded into merger cursor `after` + - `redis`: ZSET `dedup:{merger_id}:{user_id}` (+ optional custom suffix) +- builds `DeduplicationPolicy` + child `ExecutionContext(dedup=..., refill_settings=...)` +- executes child via shared executor, commits store, writes merger cursor (`page+1`, `after` for cursor backend) + +Refill/overfetch behavior: + +- duplicates trigger bounded refill loops (`max_refill_loops`) +- overfetch (`overfetch_factor`) is applied only for rewindable integer-offset cursors +- when overfetch is used, leaf cursor is rewound to inspected-count to avoid skipping unseen items + + +## 8) Dedup policy + seen stores + +### DeduplicationPolicy + +Owns key extraction + acceptance rules: + +- entity key from `dedup_key` + `missing_key_policy` +- reject duplicates already seen in current response (`seen_request_set`) +- compare candidate priority vs persisted seen priority + +Capabilities: + +- batched prefetch from store +- per-owner arbitration with deterministic tie-break: `(-dedup_priority, owner_rank, item_rank)` +- ordered single-stream acceptance (`accept_batch`) returning accepted items + inspected count + +### Seen stores + +- `CursorSeenStore` + - in-cursor map of `{key -> max_priority}` + - optional compression + max-key trimming at commit + +- `RedisSeenStore` + - cached reads via `redis_zmscore(...)` + - buffered writes via `redis_zadd_and_expire(...)` + + +## 9) Redis/JSON helpers + +- `_redis_call(client, method_name, *args, **kwargs)` + - async redis client: direct await + - sync redis client: `asyncio.to_thread(...)` + +Other helpers: + +- `jsonlib`: thin `orjson` wrapper compatible with package usage (`dumps`/`loads`) +- `dedup_utils`: cursor encode/decode + Redis ZSET helper fallbacks (`zmscore` / pipeline) + + +## 10) End-to-end call flows + +### A) Standard request (no view session, no dedup) + +1. `FeedManager.get_data(...)` builds `ExecutionContext` +2. `Executor.run(root, ctx, limit, next_page)` +3. recursive execution via plans or direct `get_data(...)` +4. returns `FeedResult(data, next_page, has_next_page)` + +### B) Slot-based merger request + +1. merger returns `SlotsPlan` +2. executor fetches owners concurrently +3. optional arbitration/refill runs +4. slots are consumed in schedule order +5. `assemble(...)` builds final result + +### C) Dedup wrapper request + +1. wrapper creates store + policy and child context +2. child executes under dedup/refill control +3. executor performs acceptance/arbitration + bounded refills +4. store commits; wrapper writes merger cursor state + +### D) View-session request + +1. wrapper resolves cache key +2. cache miss/new session -> regenerate and cache +3. cache hit -> load session list from Redis +4. return requested slice + advanced merger page diff --git a/Makefile b/Makefile index 7e9caf0..682148e 100644 --- a/Makefile +++ b/Makefile @@ -11,3 +11,17 @@ test: test_cache: pytest -s -vv -k "test_merger_view_session" + +.PHONY: test_async_chart charting + +# Runs only the async loop block + Chrome trace test. +# Writes trace.json next to this Makefile (project root). 
+test_async_chart:
+	rm -f ./trace.json
+	SMARTFEED_CHROME_TRACE=./trace.json pytest -q tests/test_async_loop_blocks_trace.py
+	@echo "\nWrote trace: $(CURDIR)/trace.json"
+	@echo "Open Chrome -> chrome://tracing -> Load -> select trace.json"
+
+# Convenience target: generate the trace + try to open chrome://tracing.
+charting: test_async_chart
+	-@open -a "Google Chrome" "chrome://tracing" 2>/dev/null || true
diff --git a/README.md b/README.md
index fe96da4..a86e8a4 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,9 @@ Python-package для формирования ленты (Feed) из клиен
 - [Использование](#использование)
   - [Установка](#установка)
   - [Формирование конфигурации](#формирование-конфигурации)
+  - [MergerDeduplication (deduplication)](#mergerdeduplication-deduplication)
+  - [MergerDeduplication parameters](#mergerdeduplication-parameters)
+  - [Important caveats (reset, cursor/redis, overfetch)](#important-caveats-reset-cursorredis-overfetch)
   - [Требования к клиентскому методу](#требования-к-клиентскому-методу)
 - [Запуск](#запуск)
@@ -68,6 +71,115 @@ poetry add git+ssh://git@github.com:epoch8/looky-timeline.git
 },
 ```
+### MergerDeduplication (deduplication)
+
+MergerDeduplication is a wrapper around a single child node (merger or subfeed) that removes duplicates by key.
+
+Key implementation properties:
+
+- Deduplication runs at the leaf (SubFeed) level rather than as post-processing of the merger output.
+  This matters: nested mergers (positional/percentage/gradient/append/distribute) keep their mixing rules.
+  When an item is dropped as a duplicate, MergerDeduplication fetches the next item from the same source.
+- The "already seen" state can be stored:
+  - in the cursor (state_backend="cursor") — works without Redis, but the cursor can grow;
+  - in Redis (state_backend="redis") — suited for large state.
+
+Example: wrap an existing feed configuration with deduplication:
+
+```json
+{
+  "version": "1",
+  "feed": {
+    "merger_id": "dedup_main",
+    "type": "merger_deduplication",
+    "dedup_key": "id",
+    "missing_key_policy": "error",
+    "state_backend": "cursor",
+    "cursor_compress": true,
+    "cursor_max_keys": 2000,
+    "overfetch_factor": 2,
+    "max_refill_loops": 20,
+    "data": {
+      "merger_id": "merger_percent",
+      "type": "merger_percentage",
+      "items": [
+        {
+          "percentage": 60,
+          "data": {
+            "subfeed_id": "sf_posts",
+            "type": "subfeed",
+            "method_name": "posts",
+            "dedup_priority": 10
+          }
+        },
+        {
+          "percentage": 40,
+          "data": {
+            "subfeed_id": "sf_ads",
+            "type": "subfeed",
+            "method_name": "ads",
+            "dedup_priority": 0
+          }
+        }
+      ]
+    }
+  }
+}
+```
+
+In the example above, if `posts` and `ads` return objects with the same `id`, the source with the higher `dedup_priority` wins.
+
+### MergerDeduplication parameters
+
+Required fields:
+
+- `merger_id: str` — unique merger ID.
+- `type: "merger_deduplication"`
+- `data` — exactly one child node (subfeed or merger).
+
+Deduplication fields:
+
+- `dedup_key: str | null` — name of the key/attribute used to detect duplicates.
+  - if `null`, the object itself is used as the key (suitable when objects are already hashable/string-like).
+- `missing_key_policy: "error" | "keep" | "drop"` (default: `"error"`)
+  - `error`: raise an error if an item has no `dedup_key`;
+  - `keep`: keep an item even when the key is missing;
+  - `drop`: discard an item without the key.
+
+Seen state (cross-page deduplication):
+
+- `state_backend: "cursor" | "redis"` (default: `"cursor"`)
+- `state_ttl_seconds: int` (default: `3600`) — TTL of the Redis state (backend=`redis` only).
+- `cursor_compress: bool` (default: `true`) — compress the seen state in the cursor backend.
+- `cursor_max_keys: int | null` — cap the seen-state size in the cursor backend (useful for keeping the cursor small).
+
+Performance/behavior:
+
+- `overfetch_factor: int` (default: `1`) — over-request inside leaves to reach `limit` faster without many refill rounds.
+- `max_refill_loops: int` (default: `20`) — upper bound on the number of refill requests per leaf.
+
+### Important caveats (reset, cursor/redis, overfetch)
+
+- State is reset when `page <= 0` or when the cursor for `merger_id` is absent.
+  - MergerDeduplication treats this as a fresh session and clears the cursors of all child nodes.
+  - For backend=`redis` it additionally deletes the state key in Redis.
+
+- With `state_backend="redis"`, a `redis_client` must be passed to `FeedManager`.
+  - The Redis state key is built as `dedup:{merger_id}:{user_id}`.
+  - A suffix can be appended via the request parameter `custom_deduplication_key` (or `custom_view_session_key`)
+    to keep separate state for different output modes.
+
+- Priority (`dedup_priority`) decides which duplicate wins; it does not affect output order.
+  - A higher `dedup_priority` means the item "wins" and is recorded as seen with that priority.
+  - The field exists on every node (merger/subfeed) and is used by MergerDeduplication during deduplication.
+
+- Overfetch is safe only for rewindable cursors.
+  - Currently overfetch is enabled only when a leaf's `next_page.after` is an integer offset.
+  - If `after` is a string/dict/any other object, it is treated as opaque and overfetch is not applied.
+
+- The real bottleneck in deduplication is refills, not wrappers or copies.
+  - If duplicates are frequent and the upstream methods are expensive, tune `overfetch_factor` and `max_refill_loops` carefully.
+
 ### Требования к клиентскому методу
 
 Клиентский метод для получения данных должен обязательно включать в себя следующие параметры:
diff --git a/pyproject.toml b/pyproject.toml
index 108a4c7..2c4d669 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ packages = [
 python = ">=3.9"
 pydantic = ">=1.10.7"
 redis = ">=4.5.5"
+orjson = ">=3.9.0"
 
 [tool.poetry.group.dev.dependencies]
 isort = "^5.12.0"
diff --git a/smartfeed/examples/example_client.py b/smartfeed/examples/example_client.py
index 9a421ff..a24e130 100644
--- a/smartfeed/examples/example_client.py
+++ b/smartfeed/examples/example_client.py
@@ -1,16 +1,15 @@
 import base64
-import json
 from typing import Optional, Union
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
+from smartfeed import jsonlib as json
+from smartfeed.pydantic_compat import parse_model
 from smartfeed.schemas import FeedResultClient, FeedResultNextPage, FeedResultNextPageInside
 
 
 class TestClientRequest(BaseModel):
-    """
-    Пример модели клиентского входящего запроса.
-    """
+    """Example client request model."""
 
     profile_id: str = Field(...)
     limit: int = Field(...)
@@ -18,20 +17,20 @@ class TestClientRequest(BaseModel): base64.urlsafe_b64encode(json.dumps({"data": {}}).encode()).decode() ) - class Config: - validate_all = True + model_config = ConfigDict(validate_default=True) - @validator("next_page") + @field_validator("next_page") + @classmethod def validate_next_page(cls, value: Union[str, FeedResultNextPage]) -> Union[str, FeedResultNextPage]: if isinstance(value, str): - return FeedResultNextPage.parse_obj(json.loads(base64.urlsafe_b64decode(value))) + payload = json.loads(base64.urlsafe_b64decode(value)) + return parse_model(FeedResultNextPage, payload) + return value class ClientMixerClass: - """ - Пример клиентского класса ClientMixer. - """ + """Example client methods for SmartFeed.""" @staticmethod async def example_method( @@ -40,16 +39,6 @@ async def example_method( next_page: FeedResultNextPageInside, limit_to_return: Optional[int] = None, ) -> FeedResultClient: - """ - Пример клиентского метода. - - :param user_id: ID профиля. - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param limit_to_return: ограничить кол-во результата. - :return: массив букв "profile_id" в количестве "limit" штук. - """ - data = [f"{user_id}_{i}" for i in range(1, 1000)] from_index = (data.index(next_page.after) + 1) if next_page.after else 0 @@ -62,9 +51,7 @@ async def example_method( next_page.after = result_data[-1] if result_data else None next_page.page += 1 - - result = FeedResultClient(data=result_data, next_page=next_page, has_next_page=True) - return result + return FeedResultClient(data=result_data, next_page=next_page, has_next_page=True) @staticmethod async def empty_method( @@ -73,21 +60,9 @@ async def empty_method( next_page: FeedResultNextPageInside, limit_to_return: Optional[int] = None, # pylint: disable=W0613 ) -> FeedResultClient: - """ - Пример клиентского метода, возвращающего пустые данные. - - :param user_id: ID профиля. - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param limit_to_return: ограничить кол-во результата. - :return: массив букв "profile_id" в количестве "limit" штук. - """ - next_page.after = None next_page.page += 1 - - result = FeedResultClient(data=[], next_page=next_page, has_next_page=False) - return result + return FeedResultClient(data=[], next_page=next_page, has_next_page=False) @staticmethod async def error_method( @@ -96,21 +71,9 @@ async def error_method( next_page: FeedResultNextPageInside, limit_to_return: Optional[int] = None, # pylint: disable=W0613 ) -> FeedResultClient: - """ - Пример клиентского метода, возвращающего пустые данные. - - :param user_id: ID профиля. - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param limit_to_return: ограничить кол-во результата. - :return: массив букв "profile_id" в количестве "limit" штук. - """ - next_page.after = None next_page.page = int(10 / 0) - - result = FeedResultClient(data=[], next_page=next_page, has_next_page=False) - return result + return FeedResultClient(data=[], next_page=next_page, has_next_page=False) @staticmethod async def doubles_method( @@ -119,23 +82,11 @@ async def doubles_method( next_page: FeedResultNextPageInside, limit_to_return: Optional[int] = None, # pylint: disable=W0613 ) -> FeedResultClient: - """ - Пример клиентского метода, возвращающего данные с дублями. - - :param user_id: ID профиля. - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param limit_to_return: ограничить кол-во результата. 
- :return: массив целых чисел, равный [i for i in range(1, 11)] после удаления дублей. - """ - data = [1, 2, 3, 4, 3, 2, 5, 6, 4, 4, 7, 8, 9, 10, 9, 9, 9] next_page.after = None next_page.page += 1 - - result = FeedResultClient(data=data, next_page=next_page, has_next_page=False) - return result + return FeedResultClient(data=data, next_page=next_page, has_next_page=False) @staticmethod async def keys_method( @@ -144,16 +95,6 @@ async def keys_method( next_page: FeedResultNextPageInside, limit_to_return: Optional[int] = None, ) -> FeedResultClient: - """ - Пример клиентского метода. - - :param user_id: ID профиля. - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param limit_to_return: ограничить кол-во результата. - :return: массив букв "profile_id" в количестве "limit" штук. - """ - data = [{"user_id": f"{user_id}_{i%10}", "value": i} for i in range(1, 1000)] from_index = (data.index(next_page.after) + 1) if next_page.after else 0 @@ -166,6 +107,4 @@ async def keys_method( next_page.after = result_data[-1] if result_data else None next_page.page += 1 - - result = FeedResultClient(data=result_data, next_page=next_page, has_next_page=True) - return result + return FeedResultClient(data=result_data, next_page=next_page, has_next_page=True) diff --git a/smartfeed/execution/context.py b/smartfeed/execution/context.py new file mode 100644 index 0000000..18fc769 --- /dev/null +++ b/smartfeed/execution/context.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Union + +import redis +from redis.asyncio import Redis as AsyncRedis + + +@dataclass +class ExecutionContext: + """Execution context propagated through the feed tree. + + Keeps internal state (policies, backends) out of user params. + """ + + methods_dict: Dict[str, Callable] + user_id: Any + redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None + + # Assigned by the caller (FeedManager / tests) to avoid circular imports. 
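+    # When unset, `ensure_executor()` below lazily imports and caches a shared `Executor`.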
+ executor: Any = None + + # Policies (optional) + dedup: Optional[object] = None + + # Execution settings (optional) + refill_settings: Optional["RefillExecutionSettings"] = None + + def ensure_redis_client(self, redis_client: Optional[Union[redis.Redis, AsyncRedis]]) -> None: + if self.redis_client is None and redis_client is not None: + self.redis_client = redis_client + + def ensure_executor(self) -> Any: + if self.executor is None: + from .executor import Executor + + self.executor = Executor() + return self.executor + + +@dataclass(frozen=True) +class RefillExecutionSettings: + overfetch_factor: int = 1 + max_refill_loops: int = 20 diff --git a/smartfeed/execution/cursors.py b/smartfeed/execution/cursors.py new file mode 100644 index 0000000..eef2fd0 --- /dev/null +++ b/smartfeed/execution/cursors.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + +from ..feed_models import BaseFeedConfigModel, FeedResultNextPage + + +@dataclass +class CursorMap: + next_page: FeedResultNextPage + + def merge_delta(self, *, base_next_page: FeedResultNextPage, owner_next_page: FeedResultNextPage) -> None: + """Merge only the cursor keys that actually changed.""" + + for key, value in owner_next_page.data.items(): + base_value = base_next_page.data.get(key) + if base_value == value: + continue + self.next_page.data[key] = value + + def reset_keys(self, keys: Iterable[str]) -> None: + for key in keys: + self.next_page.data.pop(key, None) + + @staticmethod + def can_overfetch(*, node: BaseFeedConfigModel, base_next_page: FeedResultNextPage) -> bool: + sub_id = getattr(node, "subfeed_id", None) + if not isinstance(sub_id, str): + return False + entry = base_next_page.data.get(sub_id) + if entry is None: + return False + return isinstance(entry.after, int) + + @staticmethod + def rewind_overfetch( + *, + node: BaseFeedConfigModel, + base_next_page: FeedResultNextPage, + result_next_page: FeedResultNextPage, + inspected_count: int, + batch_size: int, + ) -> None: + sub_id = getattr(node, "subfeed_id", None) + if not isinstance(sub_id, str): + return + if sub_id not in result_next_page.data: + return + + entry = result_next_page.data[sub_id] + end_after = entry.after + if not isinstance(end_after, int): + return + + base_entry = base_next_page.data.get(sub_id) + prev_after = base_entry.after if base_entry is not None else None + if not isinstance(prev_after, int): + return + + expected_end = prev_after + batch_size + if end_after == expected_end: + entry.after = prev_after + inspected_count diff --git a/smartfeed/execution/dedup_runtime.py b/smartfeed/execution/dedup_runtime.py new file mode 100644 index 0000000..1b9b26c --- /dev/null +++ b/smartfeed/execution/dedup_runtime.py @@ -0,0 +1,448 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage +from .context import ExecutionContext +from .cursors import CursorMap +from .plans import SlotsPlan + +if TYPE_CHECKING: + from .executor import Executor + + +class DedupRuntime: + """Dedup/refill orchestration. + + This owns the control flow (refill loops, slot deficit refills, etc.) + while `DeduplicationPolicy` stays focused on acceptance/arbitration decisions. 
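+
+    Refill waves for slot deficits go through `Executor.gather`, while a single
+    node's refill loop stays sequential (each request depends on the prior cursor).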
+ """ + + def __init__(self, executor: "Executor") -> None: + self._executor = executor + + def _get_refill_settings(self, ctx: ExecutionContext) -> Any: + return getattr(ctx, "refill_settings", None) + + async def run_node_with_dedup_refill( + self, + *, + node: BaseFeedConfigModel, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + params: Dict[str, Any], + initial_result: FeedResult, + ) -> FeedResult: + dedup = getattr(ctx, "dedup", None) + if dedup is None: + return initial_result + + settings = self._get_refill_settings(ctx) + overfetch_factor = max(1, int(getattr(settings, "overfetch_factor", 1))) + max_refill_loops = max(1, int(getattr(settings, "max_refill_loops", 20))) + priority = int(getattr(node, "dedup_priority", 0)) + + collected: List[Any] = [] + remaining = int(limit) + loops = 0 + + base_next_page = next_page + current_result = initial_result + request_limit = max(1, remaining) + + # NOTE: Refill loops are inherently sequential for a single node because + # each subsequent request depends on the previous cursor. + while remaining > 0: + can_overfetch = CursorMap.can_overfetch(node=node, base_next_page=base_next_page) + + accepted, inspected_count = await dedup.accept_batch( + items=list(current_result.data), + priority=priority, + limit=remaining, + ) + + if can_overfetch and request_limit > remaining: + CursorMap.rewind_overfetch( + node=node, + base_next_page=base_next_page, + result_next_page=current_result.next_page, + inspected_count=inspected_count, + batch_size=len(current_result.data), + ) + + if accepted: + collected.extend(accepted) + remaining = limit - len(collected) + + if remaining <= 0 or not current_result.has_next_page or loops >= max_refill_loops: + break + loops += 1 + + base_next_page = current_result.next_page + request_limit = max(1, remaining) + if CursorMap.can_overfetch(node=node, base_next_page=base_next_page) and overfetch_factor > 1: + request_limit = max(1, remaining * overfetch_factor) + + current_result, _plan = await self._executor._run_node_raw( + node, + ctx, + request_limit, + base_next_page, + params, + ) + + return FeedResult( + data=collected, + next_page=current_result.next_page, + has_next_page=bool(current_result.has_next_page), + ) + + async def apply_slots_plan_dedup( + self, + *, + plan: SlotsPlan, + owners: List[Any], + owner_index: Dict[int, int], + owner_buffers: Dict[int, List[Any]], + owner_results: Dict[int, FeedResult], + dedup_policy: Any, + refill_settings: Any, + cursor: CursorMap, + ) -> Tuple[Dict[int, List[Any]], Dict[int, FeedResult]]: + owner_buffers = await dedup_policy.arbitrate_owner_buffers( + owners=owners, + owner_buffers=owner_buffers, + owner_rank=owner_index, + ) + + for owner in owners: + owner_id = id(owner) + if owner_id not in owner_results: + continue + old = owner_results[owner_id] + owner_results[owner_id] = FeedResult( + data=list(owner_buffers.get(owner_id, [])), + next_page=old.next_page, + has_next_page=old.has_next_page, + ) + + deficits = self._compute_slot_deficits(plan=plan, owner_buffers=owner_buffers) + if deficits: + await self._refill_deficits( + plan=plan, + deficits=deficits, + owners=owners, + owner_index=owner_index, + owner_buffers=owner_buffers, + owner_results=owner_results, + dedup_policy=dedup_policy, + refill_settings=refill_settings, + cursor=cursor, + ) + + return owner_buffers, owner_results + + async def apply_slots_plan_refill( + self, + *, + plan: SlotsPlan, + owners: List[Any], + owner_index: Dict[int, int], + owner_buffers: Dict[int, List[Any]], 
+ owner_results: Dict[int, FeedResult], + refill_settings: Any, + cursor: CursorMap, + ) -> Tuple[Dict[int, List[Any]], Dict[int, FeedResult]]: + deficits = self._compute_slot_deficits(plan=plan, owner_buffers=owner_buffers) + if deficits: + await self._refill_deficits_without_dedup( + plan=plan, + deficits=deficits, + owners=owners, + owner_index=owner_index, + owner_buffers=owner_buffers, + owner_results=owner_results, + refill_settings=refill_settings, + cursor=cursor, + ) + + return owner_buffers, owner_results + + def _compute_slot_deficits(self, *, plan: SlotsPlan, owner_buffers: Dict[int, List[Any]]) -> Dict[int, int]: + quota_schedule = sum(int(s.max_count) for s in plan.slots) <= int(plan.limit) + deficits: Dict[int, int] = {} + consumed: Dict[int, int] = {} + remaining = int(plan.limit) + deficit_slots: List[int] = [] + + for slot in plan.slots: + if remaining <= 0: + break + + owner_id = id(slot.owner) + want = min(int(slot.max_count), remaining) + if want <= 0: + continue + + have_total = len(owner_buffers.get(owner_id, [])) + already = int(consumed.get(owner_id, 0)) + available = max(0, have_total - already) + take = min(want, available) + missing = max(0, want - take) + if missing: + deficit_slots.append(owner_id) + if quota_schedule: + deficits[owner_id] = deficits.get(owner_id, 0) + missing + consumed[owner_id] = already + take + remaining -= take + + if quota_schedule: + return deficits + if remaining <= 0: + return {} + fallback_owner_id: Optional[int] = ( + deficit_slots[-1] if deficit_slots else (id(plan.slots[-1].owner) if plan.slots else None) + ) + return {fallback_owner_id: remaining} if fallback_owner_id is not None else {} + + async def _refill_deficits( + self, + *, + plan: SlotsPlan, + deficits: Dict[int, int], + owners: List[Any], + owner_index: Dict[int, int], + owner_buffers: Dict[int, List[Any]], + owner_results: Dict[int, FeedResult], + dedup_policy: Any, + refill_settings: Any, + cursor: CursorMap, + ) -> None: + overfetch_factor = max(1, int(getattr(refill_settings, "overfetch_factor", 1))) + max_refill_loops = max(1, int(getattr(refill_settings, "max_refill_loops", 20))) + + deficit_owners: List[Any] = [o for o in owners if id(o) in deficits] + deficit_owners = sorted( + deficit_owners, + key=lambda o: ( + int(getattr(o, "dedup_priority", 0)), + owner_index.get(id(o), 0), + ), + ) + + state: Dict[int, Dict[str, Any]] = {} + for refill_owner in deficit_owners: + refill_owner_id = id(refill_owner) + missing_total = int(deficits.get(refill_owner_id, 0)) + if missing_total <= 0: + continue + + base_np = owner_results[refill_owner_id].next_page if refill_owner_id in owner_results else plan.next_page + state[refill_owner_id] = { + "missing_total": missing_total, + "remaining": missing_total, + "accepted": [], + "loops": 0, + "current_next_page": base_np, + "has_next_page": True, + } + + if not state: + return + + while True: + wave_ops: List[Tuple[Any, int, FeedResultNextPage, int, bool]] = [] + for refill_owner in deficit_owners: + refill_owner_id = id(refill_owner) + owner_state = state.get(refill_owner_id) + if owner_state is None: + continue + if owner_state["remaining"] <= 0: + continue + if not owner_state["has_next_page"]: + continue + if owner_state["loops"] >= max_refill_loops: + continue + + base_np = owner_state["current_next_page"] + remaining_before = max(1, int(owner_state["remaining"])) + request_limit = remaining_before + can_overfetch = CursorMap.can_overfetch(node=refill_owner, base_next_page=base_np) + if can_overfetch and overfetch_factor > 
1: + request_limit = max(1, remaining_before * overfetch_factor) + + wave_ops.append((refill_owner, refill_owner_id, base_np, request_limit, can_overfetch)) + + if not wave_ops: + break + + results = await self._executor.gather( + *[ + self._executor._run_owner( + plan=plan, + owner=owner, + demand=request_limit, + base_next_page=base_np, + dedup_active=True, + ) + for owner, _owner_id, base_np, request_limit, _can_overfetch in wave_ops + ] + ) + + for (owner, owner_id, base_np, request_limit, can_overfetch), result in zip(wave_ops, results): + owner_state = state[owner_id] + remaining_before = int(owner_state["remaining"]) + + owner_state["current_next_page"] = result.next_page + owner_state["has_next_page"] = bool(result.has_next_page) + cursor.merge_delta(base_next_page=plan.next_page, owner_next_page=result.next_page) + + refill_prio = int(getattr(owner, "dedup_priority", 0)) + wave_accepted, inspected_count = await dedup_policy.accept_batch( + items=list(result.data), + priority=refill_prio, + limit=max(0, remaining_before), + ) + + if can_overfetch and request_limit > remaining_before: + CursorMap.rewind_overfetch( + node=owner, + base_next_page=base_np, + result_next_page=result.next_page, + inspected_count=inspected_count, + batch_size=len(result.data), + ) + + if wave_accepted: + owner_state["accepted"].extend(wave_accepted) + owner_state["remaining"] = int(owner_state["missing_total"]) - len(owner_state["accepted"]) + + if owner_state["remaining"] > 0 and owner_state["has_next_page"]: + owner_state["loops"] += 1 + + for refill_owner in deficit_owners: + refill_owner_id = id(refill_owner) + owner_state = state.get(refill_owner_id) + if owner_state is None: + continue + + accepted = owner_state["accepted"] + if accepted: + owner_buffers.setdefault(refill_owner_id, []) + owner_buffers[refill_owner_id].extend(accepted) + + owner_results[refill_owner_id] = FeedResult( + data=list(owner_buffers.get(refill_owner_id, [])), + next_page=owner_state["current_next_page"], + has_next_page=owner_state["has_next_page"], + ) + + async def _refill_deficits_without_dedup( + self, + *, + plan: SlotsPlan, + deficits: Dict[int, int], + owners: List[Any], + owner_index: Dict[int, int], + owner_buffers: Dict[int, List[Any]], + owner_results: Dict[int, FeedResult], + refill_settings: Any, + cursor: CursorMap, + ) -> None: + max_refill_loops = max(1, int(getattr(refill_settings, "max_refill_loops", 20))) + + deficit_owners: List[Any] = [o for o in owners if id(o) in deficits] + deficit_owners = sorted( + deficit_owners, + key=lambda o: ( + int(getattr(o, "dedup_priority", 0)), + owner_index.get(id(o), 0), + ), + ) + + state: Dict[int, Dict[str, Any]] = {} + for refill_owner in deficit_owners: + refill_owner_id = id(refill_owner) + missing_total = int(deficits.get(refill_owner_id, 0)) + if missing_total <= 0: + continue + + base_np = owner_results[refill_owner_id].next_page if refill_owner_id in owner_results else plan.next_page + state[refill_owner_id] = { + "missing_total": missing_total, + "remaining": missing_total, + "accepted": [], + "loops": 0, + "current_next_page": base_np, + "has_next_page": True, + } + + if not state: + return + + while True: + wave_ops: List[Tuple[Any, int, FeedResultNextPage, int]] = [] + for refill_owner in deficit_owners: + refill_owner_id = id(refill_owner) + owner_state = state.get(refill_owner_id) + if owner_state is None: + continue + if owner_state["remaining"] <= 0: + continue + if not owner_state["has_next_page"]: + continue + if owner_state["loops"] >= 
max_refill_loops: + continue + + base_np = owner_state["current_next_page"] + request_limit = max(1, int(owner_state["remaining"])) + wave_ops.append((refill_owner, refill_owner_id, base_np, request_limit)) + + if not wave_ops: + break + + results = await self._executor.gather( + *[ + self._executor._run_owner( + plan=plan, + owner=owner, + demand=request_limit, + base_next_page=base_np, + dedup_active=False, + ) + for owner, _owner_id, base_np, request_limit in wave_ops + ] + ) + + for (_owner, owner_id, _base_np, _request_limit), result in zip(wave_ops, results): + owner_state = state[owner_id] + remaining_before = int(owner_state["remaining"]) + + owner_state["current_next_page"] = result.next_page + owner_state["has_next_page"] = bool(result.has_next_page) + cursor.merge_delta(base_next_page=plan.next_page, owner_next_page=result.next_page) + + if remaining_before > 0: + owner_state["accepted"].extend(list(result.data)[:remaining_before]) + owner_state["remaining"] = int(owner_state["missing_total"]) - len(owner_state["accepted"]) + + if owner_state["remaining"] > 0 and owner_state["has_next_page"]: + owner_state["loops"] += 1 + + for refill_owner in deficit_owners: + refill_owner_id = id(refill_owner) + owner_state = state.get(refill_owner_id) + if owner_state is None: + continue + + accepted = owner_state["accepted"] + if accepted: + owner_buffers.setdefault(refill_owner_id, []) + owner_buffers[refill_owner_id].extend(accepted) + + owner_results[refill_owner_id] = FeedResult( + data=list(owner_buffers.get(refill_owner_id, [])), + next_page=owner_state["current_next_page"], + has_next_page=owner_state["has_next_page"], + ) diff --git a/smartfeed/execution/executor.py b/smartfeed/execution/executor.py new file mode 100644 index 0000000..d6a92e8 --- /dev/null +++ b/smartfeed/execution/executor.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import asyncio +import inspect +from typing import Any, Dict, List, Optional, Tuple + +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage, _pydantic_deep_copy +from .context import ExecutionContext +from .cursors import CursorMap +from .dedup_runtime import DedupRuntime +from .plans import CallablePlan, Plan, SlotSpec, SlotsPlan + + +class Executor: + """Shared execution engine. + + Owns recursion and concurrency. Nodes can optionally expose `build_plan(...)`. + """ + + async def run( + self, + node: BaseFeedConfigModel, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + **params: Any, + ) -> FeedResult: + result, plan = await self._run_node_raw(node, ctx, limit, next_page, params) + + dedup = getattr(ctx, "dedup", None) + if dedup is None: + return result + + if isinstance(plan, SlotsPlan): + return result + + return await self._dedup_runtime().run_node_with_dedup_refill( + node=node, + ctx=ctx, + limit=limit, + next_page=next_page, + params=params, + initial_result=result, + ) + + def _dedup_runtime(self) -> DedupRuntime: + runtime = getattr(self, "_dedup_runtime_instance", None) + if runtime is None: + runtime = DedupRuntime(self) + setattr(self, "_dedup_runtime_instance", runtime) + return runtime + + async def execute_plan(self, plan: Plan) -> FeedResult: + """Interpret and execute a declarative plan. + + Plans must not perform execution themselves; they are data structures. 
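+
+        `SlotsPlan` is interpreted by the slot scheduler; `CallablePlan` is
+        simply awaited as `plan.fn(self)`.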
+ """ + + if isinstance(plan, SlotsPlan): + return await self._execute_slots_plan(plan) + if isinstance(plan, CallablePlan): + return await plan.fn(self) + raise TypeError(f"Unknown plan type: {type(plan)!r}") + + async def gather(self, *coros: Any) -> List[Any]: + """Execute coroutines concurrently. + + Centralizes concurrency in the executor layer. + """ + + return list(await asyncio.gather(*coros)) + + async def _maybe_await(self, value: Any) -> Any: + if inspect.isawaitable(value): + return await value + return value + + async def _run_node_raw( + self, + node: BaseFeedConfigModel, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + params: Dict[str, Any], + ) -> Tuple[FeedResult, Optional[Plan]]: + build_plan = getattr(node, "build_plan", None) + if callable(build_plan): + plan: Plan = build_plan(ctx=ctx, limit=limit, next_page=next_page, **params) + result = await self.execute_plan(plan) + return result, plan + + result = await node.get_data( + methods_dict=ctx.methods_dict, + user_id=ctx.user_id, + limit=limit, + next_page=next_page, + redis_client=ctx.redis_client, + ctx=ctx, + **params, + ) + return result, None + + async def _execute_slots_plan(self, plan: SlotsPlan) -> FeedResult: + if plan.limit <= 0: + assembled = await self._maybe_await(plan.assemble([], plan.next_page, {})) + return assembled + + working_next_page = _pydantic_deep_copy(plan.next_page) + cursor = CursorMap(working_next_page) + owners, owner_index, owner_max_demand = self._collect_plan_owners(plan) + dedup_policy = getattr(plan.ctx, "dedup", None) + refill_settings = getattr(plan.ctx, "refill_settings", None) + dedup_active = dedup_policy is not None + + owner_buffers, owner_results = await self._run_plan_owners( + plan=plan, + owners=owners, + owner_max_demand=owner_max_demand, + dedup_active=dedup_active, + cursor=cursor, + ) + + if dedup_policy is not None: + owner_buffers, owner_results = await self._dedup_runtime().apply_slots_plan_dedup( + plan=plan, + owners=owners, + owner_index=owner_index, + owner_buffers=owner_buffers, + owner_results=owner_results, + dedup_policy=dedup_policy, + refill_settings=refill_settings, + cursor=cursor, + ) + elif refill_settings is not None: + owner_buffers, owner_results = await self._dedup_runtime().apply_slots_plan_refill( + plan=plan, + owners=owners, + owner_index=owner_index, + owner_buffers=owner_buffers, + owner_results=owner_results, + refill_settings=refill_settings, + cursor=cursor, + ) + + output = self._consume_slots(plan=plan, owner_buffers=owner_buffers) + assembled = await self._maybe_await(plan.assemble(output, cursor.next_page, owner_results)) + return assembled + + def _collect_plan_owners(self, plan: SlotsPlan) -> tuple[List[Any], Dict[int, int], Dict[int, int]]: + owners: List[Any] = [] + owner_index: Dict[int, int] = {} + owner_demand: Dict[int, int] = {} + for slot in plan.slots: + owner_id = id(slot.owner) + if owner_id not in owner_index: + owner_index[owner_id] = len(owners) + owners.append(slot.owner) + owner_demand[owner_id] = owner_demand.get(owner_id, 0) + int(slot.max_count) + return owners, owner_index, owner_demand + + async def _run_owner( + self, + *, + plan: SlotsPlan, + owner: Any, + demand: int, + base_next_page: FeedResultNextPage, + dedup_active: bool, + ) -> FeedResult: + isolated_next_page = _pydantic_deep_copy(base_next_page) + owner_ctx = plan.ctx + if dedup_active: + owner_ctx = ExecutionContext( + methods_dict=plan.ctx.methods_dict, + user_id=plan.ctx.user_id, + redis_client=plan.ctx.redis_client, + 
executor=plan.ctx.executor, + dedup=None, + # Keep refill settings so nested slots plans can still compensate + # owner deficits while top-level dedup arbitration remains centralized. + refill_settings=plan.ctx.refill_settings, + ) + return await self.run(owner, owner_ctx, demand, isolated_next_page, **plan.params) + + async def _run_plan_owners( + self, + *, + plan: SlotsPlan, + owners: List[Any], + owner_max_demand: Dict[int, int], + dedup_active: bool, + cursor: CursorMap, + ) -> tuple[Dict[int, List[Any]], Dict[int, FeedResult]]: + owner_buffers: Dict[int, List[Any]] = {id(o): [] for o in owners} + owner_results: Dict[int, FeedResult] = {} + + ops: List[tuple[Any, int]] = [] + for owner in owners: + if plan.owner_fetch_limits is not None and id(owner) in plan.owner_fetch_limits: + demand = int(plan.owner_fetch_limits[id(owner)]) + else: + demand = min(plan.limit, int(owner_max_demand.get(id(owner), 0))) + if demand > 0: + ops.append((owner, demand)) + + if not ops: + return owner_buffers, owner_results + + results = await self.gather( + *[ + self._run_owner( + plan=plan, + owner=owner, + demand=demand, + base_next_page=plan.next_page, + dedup_active=dedup_active, + ) + for owner, demand in ops + ] + ) + for (owner, _demand), owner_result in zip(ops, results): + owner_results[id(owner)] = owner_result + owner_buffers[id(owner)] = list(owner_result.data) + cursor.merge_delta( + base_next_page=plan.next_page, + owner_next_page=owner_result.next_page, + ) + + return owner_buffers, owner_results + + def _consume_slots(self, *, plan: SlotsPlan, owner_buffers: Dict[int, List[Any]]) -> List[Any]: + output: List[Any] = [] + for slot in plan.slots: + if len(output) >= plan.limit: + break + + remaining = plan.limit - len(output) + take = min(int(slot.max_count), remaining) + if take <= 0: + continue + + owner_buffer = owner_buffers.get(id(slot.owner), []) + if not owner_buffer: + continue + + chunk = owner_buffer[:take] + del owner_buffer[: len(chunk)] + output.extend(chunk) + + return output + + +__all__ = [ + "Executor", + "Plan", + "CallablePlan", + "SlotSpec", + "SlotsPlan", +] diff --git a/smartfeed/execution/plans.py b/smartfeed/execution/plans.py new file mode 100644 index 0000000..8a23a29 --- /dev/null +++ b/smartfeed/execution/plans.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, List, Optional, Protocol + +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage +from .context import ExecutionContext + + +class Plan(Protocol): + """Declarative execution plan. + + Plans describe what to run; the `Executor` is responsible for interpreting + and executing them. + """ + + +@dataclass(frozen=True) +class CallablePlan: + """A plan implemented as an async callable. + + Useful for mergers whose child limits depend on previous child results. + """ + + fn: Callable[["Executor"], Awaitable[FeedResult]] + + +@dataclass(frozen=True) +class SlotSpec: + """A slot segment owned by a child node. + + Output order is defined by the sequence of SlotSpecs. + """ + + owner: BaseFeedConfigModel + max_count: int + + +@dataclass(frozen=True) +class SlotsPlan: + """Plan expressed as slot ownership + an assembly function. + + The executor will fetch children (possibly in priority order) and then assemble + results in the slot schedule order. 
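+
+    `owner_fetch_limits` is keyed by `id(owner)` and overrides the default demand
+    (the sum of that owner's slot quotas, capped at `limit`).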
+ """ + + ctx: ExecutionContext + limit: int + next_page: FeedResultNextPage + params: Dict[str, Any] + slots: List[SlotSpec] + assemble: Callable[[List[Any], FeedResultNextPage, Dict[int, FeedResult]], Any] + owner_fetch_limits: Optional[Dict[int, int]] = None + + +# NOTE: `Executor` is imported only for typing to avoid an import cycle. +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .executor import Executor diff --git a/smartfeed/feed_models.py b/smartfeed/feed_models.py new file mode 100644 index 0000000..38d26d6 --- /dev/null +++ b/smartfeed/feed_models.py @@ -0,0 +1,197 @@ +import asyncio +import inspect +from dataclasses import dataclass +from random import shuffle +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Literal, Optional, Union, cast + +import redis +from pydantic import BaseModel +from redis.asyncio import Redis as AsyncRedis +from redis.asyncio import RedisCluster as AsyncRedisCluster + +if TYPE_CHECKING: + from .execution.context import ExecutionContext + + +def _is_async_redis_client(client: Any) -> bool: + return isinstance(client, (AsyncRedis, AsyncRedisCluster)) + + +async def _redis_call(client: Any, method_name: str, *args: Any, **kwargs: Any) -> Any: + """Call a Redis method without blocking the event loop. + + - For `redis.asyncio` clients, calls are awaited directly. + - For sync `redis.Redis`, calls are offloaded via `asyncio.to_thread()`. + """ + + method = getattr(client, method_name) + if _is_async_redis_client(client): + return await method(*args, **kwargs) + return await asyncio.to_thread(method, *args, **kwargs) + + +def _pydantic_deep_copy(model: Any) -> Any: + """Deep copy helper compatible with Pydantic v1 and v2.""" + + if hasattr(model, "model_copy"): + return model.model_copy(deep=True) + return model.copy(deep=True) + + +class FeedResultNextPageInside(BaseModel): + """Cursor model for one feed node.""" + + page: int = 1 + after: Any = None + + +class FeedResultNextPage(BaseModel): + """Cursor model for a whole feed traversal.""" + + data: Dict[str, FeedResultNextPageInside] + + +class FeedResult(BaseModel): + """Normalized output of any feed node `get_data()`.""" + + data: List + next_page: FeedResultNextPage + has_next_page: bool + + +class FeedResultClient(BaseModel): + """Result returned by client subfeed methods.""" + + data: List + next_page: FeedResultNextPageInside + has_next_page: bool + + +class BaseFeedConfigModel(BaseModel): + """Base class for merger/subfeed config models.""" + + # Higher value means the item should "win" deduplication when duplicates exist. + # This is primarily used by MergerDeduplication and by mergers when a dedup wrapper is active. + dedup_priority: int = 0 + + async def get_data( + self, + methods_dict: Dict[str, Callable], + user_id: Any, + limit: int, + next_page: FeedResultNextPage, + redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, + ctx: Optional["ExecutionContext"] = None, + **params: Any, + ) -> FeedResult: + """Default merger execution path via the shared executor.""" + + if not callable(getattr(self, "build_plan", None)): + raise NotImplementedError( + f"{self.__class__.__name__} must implement build_plan(...) or override get_data(...)." 
+ ) + + if ctx is None: + from .execution.context import ExecutionContext as _ExecutionContext + + ctx = _ExecutionContext(methods_dict=methods_dict, user_id=user_id, redis_client=redis_client) + else: + ctx.ensure_redis_client(redis_client) + + executor = ctx.ensure_executor() + return await executor.run(self, ctx, limit, next_page, **params) + + +@dataclass +class _SubFeedMethodSpec: + method: Callable + args: List[str] + + +class SubFeed(BaseFeedConfigModel): + """Leaf node pointing at a client method.""" + + subfeed_id: str + type: Literal["subfeed"] + method_name: str + subfeed_params: Dict[str, Any] = {} + raise_error: Optional[bool] = True + shuffle: bool = False + + def _get_method_spec(self, methods_dict: Dict[str, Callable]) -> _SubFeedMethodSpec: + method = methods_dict[self.method_name] + method_spec = getattr(method, "_smartfeed_original", method) + method_args = inspect.getfullargspec(method_spec).args + return _SubFeedMethodSpec(method=method, args=list(method_args)) + + async def get_data( + self, + methods_dict: Dict[str, Callable], + user_id: Any, + limit: int, + next_page: FeedResultNextPage, + redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, + ctx: Optional["ExecutionContext"] = None, + **params: Any, + ) -> FeedResult: + if ctx is None: + from .execution.context import ExecutionContext as _ExecutionContext + + ctx = _ExecutionContext(methods_dict=methods_dict, user_id=user_id, redis_client=redis_client) + + subfeed_next_page = FeedResultNextPageInside( + page=next_page.data[self.subfeed_id].page if self.subfeed_id in next_page.data else 1, + after=next_page.data[self.subfeed_id].after if self.subfeed_id in next_page.data else None, + ) + + method_spec = self._get_method_spec(methods_dict) + + method_params: Dict[str, Any] = {} + for arg in method_spec.args: + if arg in params: + method_params[arg] = params[arg] + + method = method_spec.method + is_async = inspect.iscoroutinefunction(method) or inspect.iscoroutinefunction(getattr(method, "__call__", None)) + + try: + if is_async: + method_result = await method( + user_id=user_id, + limit=limit, + next_page=subfeed_next_page, + **method_params, + **self.subfeed_params, + ) + else: + method_result = await asyncio.to_thread( + method, + user_id=user_id, + limit=limit, + next_page=subfeed_next_page, + **method_params, + **self.subfeed_params, + ) + except Exception: + if self.raise_error: + raise + + method_result = FeedResultClient( + data=[], + next_page=subfeed_next_page, + has_next_page=False, + ) + + if not isinstance(method_result, FeedResultClient): + raise TypeError('SubFeed function must return "FeedResultClient" instance.') + + if self.shuffle: + shuffle(method_result.data) + + return FeedResult( + data=method_result.data, + next_page=FeedResultNextPage( + data={self.subfeed_id: cast(FeedResultNextPageInside, method_result.next_page)} + ), + has_next_page=bool(method_result.has_next_page), + ) diff --git a/smartfeed/jsonlib.py b/smartfeed/jsonlib.py new file mode 100644 index 0000000..9ab650f --- /dev/null +++ b/smartfeed/jsonlib.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional + +import orjson + +DefaultFn = Callable[[Any], Any] + + +def dumps( + obj: Any, + *, + default: Optional[DefaultFn] = None, + sort_keys: bool = False, +) -> str: + """Serialize *obj* to JSON text using orjson. + + This is a small compatibility layer meant to cover the subset of the stdlib + `json.dumps` API used inside this package. 
+ + Key differences vs `orjson.dumps`: + - Returns `str` (UTF-8) instead of `bytes`. + - Supports `default=` and `sort_keys=`. + """ + + option = 0 + if sort_keys: + option |= orjson.OPT_SORT_KEYS + + return orjson.dumps(obj, default=default, option=option).decode("utf-8") + + +def loads(data: Any) -> Any: + """Deserialize JSON from *data* using orjson. + + Accepts `str`, `bytes`, `bytearray`, or `memoryview` (same spirit as + stdlib `json.loads`). + """ + + return orjson.loads(data) diff --git a/smartfeed/manager.py b/smartfeed/manager.py index e91bbe9..d6c7a76 100644 --- a/smartfeed/manager.py +++ b/smartfeed/manager.py @@ -3,6 +3,8 @@ import redis from redis.asyncio import Redis as AsyncRedis +from .execution.context import ExecutionContext +from .pydantic_compat import parse_model from .schemas import FeedConfig, FeedResult, FeedResultNextPage @@ -20,7 +22,7 @@ def __init__(self, config: Dict, methods_dict: Dict, redis_client: Optional[Unio :param redis_client: объект клиента Redis (для конфигурации с view_session = True). """ - self.feed_config = FeedConfig.parse_obj(config) + self.feed_config = parse_model(FeedConfig, config) self.methods_dict = methods_dict self.redis_client = redis_client @@ -35,12 +37,6 @@ async def get_data(self, user_id: Any, limit: int, next_page: FeedResultNextPage :return: результат получения данных согласно конфигурации фида. """ - result = await self.feed_config.feed.get_data( - methods_dict=self.methods_dict, - user_id=user_id, - limit=limit, - next_page=next_page, - redis_client=self.redis_client, - **params, - ) - return result + ctx = ExecutionContext(methods_dict=self.methods_dict, user_id=user_id, redis_client=self.redis_client) + executor = ctx.ensure_executor() + return await executor.run(self.feed_config.feed, ctx, limit, next_page, **params) diff --git a/smartfeed/mergers/__init__.py b/smartfeed/mergers/__init__.py new file mode 100644 index 0000000..24b1c8e --- /dev/null +++ b/smartfeed/mergers/__init__.py @@ -0,0 +1,24 @@ +"""Merger implementations. + +Each merger schema lives in its own module. +`smartfeed.schemas` re-exports these classes for backwards compatibility. 
+""" + +from .append import MergerAppend +from .append_distribute import MergerAppendDistribute +from .deduplication import MergerDeduplication +from .percentage import MergerPercentage, MergerPercentageItem +from .percentage_gradient import MergerPercentageGradient +from .positional import MergerPositional +from .view_session import MergerViewSession + +__all__ = [ + "MergerAppend", + "MergerAppendDistribute", + "MergerDeduplication", + "MergerPercentage", + "MergerPercentageItem", + "MergerPercentageGradient", + "MergerPositional", + "MergerViewSession", +] diff --git a/smartfeed/mergers/append.py b/smartfeed/mergers/append.py new file mode 100644 index 0000000..9c5c5c6 --- /dev/null +++ b/smartfeed/mergers/append.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from random import shuffle +from typing import TYPE_CHECKING, Any, Dict, List, Literal, cast + +from ..execution.context import ExecutionContext +from ..execution.executor import SlotSpec, SlotsPlan +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage + +if TYPE_CHECKING: + from ..schemas import FeedTypes + + +class MergerAppend(BaseFeedConfigModel): + """Append merger.""" + + merger_id: str + type: Literal["merger_append"] + items: List[FeedTypes] + shuffle: bool = False + + def build_plan( + self, + *, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + **params: Any, + ) -> SlotsPlan: + slots = [SlotSpec(owner=cast(BaseFeedConfigModel, item), max_count=limit) for item in self.items] + + def _assemble( + output: List[Any], merged_next_page: FeedResultNextPage, owner_results: Dict[int, FeedResult] + ) -> FeedResult: + has_next_page = any(r.has_next_page for r in owner_results.values()) + result = FeedResult(data=output, next_page=merged_next_page, has_next_page=has_next_page) + if self.shuffle: + shuffle(result.data) + return result + + return SlotsPlan( + ctx=ctx, + limit=limit, + next_page=next_page, + params=dict(params), + slots=slots, + assemble=_assemble, + ) diff --git a/smartfeed/mergers/append_distribute.py b/smartfeed/mergers/append_distribute.py new file mode 100644 index 0000000..220e3e3 --- /dev/null +++ b/smartfeed/mergers/append_distribute.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from collections import defaultdict, deque +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional + +from typing_extensions import no_type_check + +from ..execution.context import ExecutionContext +from ..execution.executor import SlotSpec, SlotsPlan +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage + +if TYPE_CHECKING: + from ..schemas import FeedTypes + + +class MergerAppendDistribute(BaseFeedConfigModel): + """Merger that uniformly distributes items by a key.""" + + merger_id: str + type: Literal["merger_distribute"] + items: List["FeedTypes"] + distribution_key: str + sorting_key: Optional[str] = None + sorting_desc: bool = False + + @no_type_check + def _uniform_distribute(self, data: list) -> list: + if self.sorting_key: + data = sorted(data, key=lambda x: x[self.sorting_key], reverse=self.sorting_desc) + + grouped_entries = defaultdict(deque) + for entry in data: + grouped_entries[entry[self.distribution_key]].append(entry) + result = [] + prev_profile_id = None + while any(grouped_entries.values()): + for profile_id in list(grouped_entries.keys()): + if grouped_entries[profile_id]: + if profile_id != prev_profile_id or len(grouped_entries) == 1: + result.append(grouped_entries[profile_id].popleft()) + 
prev_profile_id = profile_id + if not grouped_entries[profile_id]: + del grouped_entries[profile_id] + else: + del grouped_entries[profile_id] + + return result + + def build_plan( + self, + *, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + **params: Any, + ) -> SlotsPlan: + slots = [SlotSpec(owner=item, max_count=limit) for item in self.items] + + def _assemble( + output: List[Any], merged_next_page: FeedResultNextPage, owner_results: Dict[int, FeedResult] + ) -> FeedResult: + has_next_page = any(r.has_next_page for r in owner_results.values()) + distributed = self._uniform_distribute(output) + return FeedResult(data=distributed, next_page=merged_next_page, has_next_page=has_next_page) + + return SlotsPlan( + ctx=ctx, + limit=limit, + next_page=next_page, + params=dict(params), + slots=slots, + assemble=_assemble, + ) diff --git a/smartfeed/mergers/deduplication.py b/smartfeed/mergers/deduplication.py new file mode 100644 index 0000000..1b35a7b --- /dev/null +++ b/smartfeed/mergers/deduplication.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional + +from pydantic import PrivateAttr, model_validator + +from ..execution.context import ExecutionContext, RefillExecutionSettings +from ..execution.cursors import CursorMap +from ..execution.executor import CallablePlan +from ..feed_models import ( + BaseFeedConfigModel, + FeedResult, + FeedResultNextPage, + FeedResultNextPageInside, + _pydantic_deep_copy, +) +from ..policies.dedup import DeduplicationPolicy +from ..policies.seen_store import CursorSeenStore, RedisSeenStore, SeenStore + +if TYPE_CHECKING: + from ..schemas import FeedTypes + + +class MergerDeduplication(BaseFeedConfigModel): + """Merger that deduplicates while preserving child mixing/position semantics.""" + + merger_id: str + type: Literal["merger_deduplication"] + data: "FeedTypes" + + dedup_key: Optional[str] = None + missing_key_policy: Literal["error", "keep", "drop"] = "error" + + state_backend: Literal["cursor", "redis"] = "cursor" + state_ttl_seconds: int = 3600 + cursor_compress: bool = True + cursor_max_keys: Optional[int] = None + + overfetch_factor: int = 1 + + max_refill_loops: int = 20 + + _descendant_cursor_keys_cache: Optional[set[str]] = PrivateAttr(default=None) + + @model_validator(mode="after") + def validate_merger_deduplication(self) -> "MergerDeduplication": + if self.overfetch_factor < 1: + raise ValueError('"overfetch_factor" must be >= 1') + if self.max_refill_loops < 1: + raise ValueError('"max_refill_loops" must be >= 1') + return self + + def _collect_descendant_cursor_keys(self, feed: BaseFeedConfigModel) -> set[str]: + keys: set[str] = set() + stack = [feed] + while stack: + node = stack.pop() + + for attr in ("subfeed_id", "merger_id"): + value = getattr(node, attr, None) + if isinstance(value, str) and value: + keys.add(value) + + for child in ( + getattr(node, "data", None), + getattr(node, "positional", None), + getattr(node, "default", None), + ): + if isinstance(child, BaseFeedConfigModel): + stack.append(child) + + for wrapper in (getattr(node, "item_from", None), getattr(node, "item_to", None)): + inner = getattr(wrapper, "data", None) + if isinstance(inner, BaseFeedConfigModel): + stack.append(inner) + + items = getattr(node, "items", None) + if isinstance(items, list): + for item in items: + inner = item if isinstance(item, BaseFeedConfigModel) else getattr(item, "data", None) + if isinstance(inner, BaseFeedConfigModel): + stack.append(inner) + 
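+        # "keys" now holds every subfeed_id / merger_id reachable from the
+        # wrapped subtree; these are the cursor entries reset on a fresh
+        # dedup session.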
+ return keys + + def _get_descendant_cursor_keys_cached(self) -> set[str]: + cached = self._descendant_cursor_keys_cache + if cached is None: + cached = self._collect_descendant_cursor_keys(self.data) + self._descendant_cursor_keys_cache = cached + return cached + + def _reset_descendant_cursors(self, next_page: FeedResultNextPage) -> None: + descendant_keys = self._get_descendant_cursor_keys_cached() + CursorMap(next_page).reset_keys(descendant_keys) + + def _build_redis_state_key(self, user_id: Any, params: Dict[str, Any]) -> str: + suffix = params.get("custom_deduplication_key") or params.get("custom_view_session_key") + if suffix: + return f"dedup:{self.merger_id}:{user_id}:{suffix}" + return f"dedup:{self.merger_id}:{user_id}" + + def build_plan( + self, + *, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + **params: Any, + ) -> CallablePlan: + async def _run(executor: Any) -> FeedResult: + if limit <= 0: + return FeedResult(data=[], next_page=next_page, has_next_page=False) + + if ctx.executor is None: + ctx.executor = executor + + entry = next_page.data.get(self.merger_id) + requested_page = entry.page if entry is not None else None + is_fresh_session = requested_page is None or (isinstance(requested_page, int) and requested_page <= 0) + + redis_client = ctx.redis_client + if self.state_backend == "redis" and not redis_client: + raise ValueError("Redis client must be provided if using MergerDeduplication with state_backend=redis") + + working_next_page = _pydantic_deep_copy(next_page) + if is_fresh_session: + self._reset_descendant_cursors(working_next_page) + + seen_request_set: set[str] = set() + store: SeenStore + if self.state_backend == "cursor": + cursor_entry = next_page.data.get(self.merger_id) + store = CursorSeenStore.from_after( + cursor_entry.after if (cursor_entry is not None and not is_fresh_session) else None, + cursor_compress=self.cursor_compress, + cursor_max_keys=self.cursor_max_keys, + ) + else: + assert redis_client is not None + redis_state_key = self._build_redis_state_key(user_id=ctx.user_id, params=params) + store = RedisSeenStore.create( + redis_client=redis_client, + redis_key=redis_state_key, + ttl_seconds=self.state_ttl_seconds, + ) + if is_fresh_session: + await store.reset() + + policy = DeduplicationPolicy( + dedup_key=self.dedup_key, + missing_key_policy=self.missing_key_policy, + store=store, + seen_request_set=seen_request_set, + ) + + refill_settings = RefillExecutionSettings( + overfetch_factor=self.overfetch_factor, + max_refill_loops=self.max_refill_loops, + ) + child_ctx = ExecutionContext( + methods_dict=ctx.methods_dict, + user_id=ctx.user_id, + redis_client=ctx.redis_client, + executor=ctx.executor, + dedup=policy, + refill_settings=refill_settings, + ) + + child = _pydantic_deep_copy(self.data) + child_result = await executor.run(child, child_ctx, limit, working_next_page, **params) + + commit_result: Any = await store.commit() + merger_after: Any = commit_result if self.state_backend == "cursor" else None + + page = next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1 + result_next_page = _pydantic_deep_copy(child_result.next_page) + result_next_page.data[self.merger_id] = FeedResultNextPageInside(page=page + 1, after=merger_after) + return FeedResult( + data=child_result.data, + next_page=result_next_page, + has_next_page=child_result.has_next_page, + ) + + return CallablePlan(fn=_run) diff --git a/smartfeed/mergers/percentage.py b/smartfeed/mergers/percentage.py new file mode 100644 
index 0000000..9ea2096 --- /dev/null +++ b/smartfeed/mergers/percentage.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from random import shuffle +from typing import TYPE_CHECKING, Any, Dict, List, Literal, cast + +from pydantic import BaseModel + +from ..execution.context import ExecutionContext +from ..execution.executor import SlotSpec, SlotsPlan +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage + +if TYPE_CHECKING: + from ..schemas import FeedTypes + + +class MergerPercentageItem(BaseModel): + """One percentage slot.""" + + percentage: int + data: FeedTypes + + +class MergerPercentage(BaseFeedConfigModel): + """Percentage-based mixing merger.""" + + merger_id: str + type: Literal["merger_percentage"] + items: List[MergerPercentageItem] + shuffle: bool = False + + @staticmethod + def _merge_items_data(items_data: List[List]) -> List: + result: List = [] + cursor: List[Dict] = [] + + min_length = min(len(item_data) for item_data in items_data) or 1 + for item_data in items_data: + cursor.append( + { + "items": item_data, + "current": 0, + "size": round(len(item_data) / min_length), + } + ) + + full_length = sum(len(item_data) for item_data in items_data) + while len(result) < full_length: + for item_cursor in cursor: + items = item_cursor["items"] + start = item_cursor["current"] + end = start + item_cursor["size"] if start + item_cursor["size"] < len(items) else len(items) + result.extend(items[start:end]) + item_cursor["current"] = end + + return result + + def build_plan( + self, + *, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + **params: Any, + ) -> SlotsPlan: + owners: List[BaseFeedConfigModel] = [cast(BaseFeedConfigModel, item.data) for item in self.items] + + slot_limits: List[int] = [] + remainders: List[tuple[int, int]] = [] + total_percentage = sum(int(item.percentage) for item in self.items) + + for idx, item in enumerate(self.items): + raw = int(limit) * int(item.percentage) + child_limit = raw // 100 + slot_limits.append(max(0, child_limit)) + remainders.append((raw % 100, idx)) + + # avoid underfilling for the common "percentages sum to 100" case + if total_percentage == 100: + missing = max(0, int(limit) - sum(slot_limits)) + if missing > 0: + for _rem, idx in sorted(remainders, key=lambda x: (-x[0], x[1])): + if missing <= 0: + break + slot_limits[idx] += 1 + missing -= 1 + + slots: List[SlotSpec] = [ + SlotSpec(owner=owner, max_count=max(0, int(slot_limits[idx]))) for idx, owner in enumerate(owners) + ] + + def _assemble( + output: List[Any], + merged_next_page: FeedResultNextPage, + owner_results: Dict[int, FeedResult], + ) -> FeedResult: + items_data: List[List[Any]] = [] + has_next_page = False + + for owner in owners: + child_res = owner_results.get(id(owner)) + if child_res is None: + items_data.append([]) + continue + items_data.append(list(child_res.data)) + has_next_page = has_next_page or bool(child_res.has_next_page) + + data = self._merge_items_data(items_data=items_data) + if self.shuffle: + shuffle(data) + + return FeedResult(data=data, next_page=merged_next_page, has_next_page=has_next_page) + + return SlotsPlan( + ctx=ctx, + limit=limit, + next_page=next_page, + params=dict(params), + slots=slots, + assemble=_assemble, + ) diff --git a/smartfeed/mergers/percentage_gradient.py b/smartfeed/mergers/percentage_gradient.py new file mode 100644 index 0000000..fb70891 --- /dev/null +++ b/smartfeed/mergers/percentage_gradient.py @@ -0,0 +1,155 @@ +from random import shuffle +from typing import 
Any, Dict, List, Literal, cast
+
+from pydantic import model_validator
+
+from ..execution.context import ExecutionContext
+from ..execution.executor import SlotSpec, SlotsPlan
+from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage, FeedResultNextPageInside
+from .percentage import MergerPercentageItem
+
+
+class MergerPercentageGradient(BaseFeedConfigModel):
+    """Percentage-gradient merger."""
+
+    merger_id: str
+    type: Literal["merger_percentage_gradient"]
+    item_from: MergerPercentageItem
+    item_to: MergerPercentageItem
+    step: int
+    size_to_step: int
+    shuffle: bool = False
+
+    @model_validator(mode="after")
+    def validate_merger_percentage_gradient(self) -> "MergerPercentageGradient":
+        if self.step < 1 or self.step > 100:
+            raise ValueError('"step" must be in range from 1 to 100')
+        if self.size_to_step < 1:
+            raise ValueError('"size_to_step" must be >= 1')
+        return self
+
+    def _calculate_limits_and_percents(self, page: int, limit: int) -> Dict:
+        result: Dict = {
+            "limit_from": 0,
+            "limit_to": 0,
+            "percentages": [],
+        }
+
+        percentage_from = self.item_from.percentage
+        percentage_to = self.item_to.percentage
+        start_position = limit * (page - 1)
+        first_iter = True
+
+        for i in range(self.size_to_step, limit * page + self.size_to_step, self.size_to_step):
+            if not first_iter and percentage_to < 100:
+                percentage_from -= self.step
+                percentage_to += self.step
+
+            if percentage_to > 100 or percentage_from < 0:
+                percentage_from = 0
+                percentage_to = 100
+
+            if i > start_position:
+                iter_limit = (limit * page - start_position) if i > limit * page else (i - start_position)
+                start_position = i
+
+                if result["percentages"] and result["percentages"][-1]["to"] >= 100:
+                    result["limit_to"] += iter_limit
+                    result["percentages"][-1]["limit"] += iter_limit
+                    result["percentages"][-1]["to_take"] += iter_limit
+                else:
+                    from_take = iter_limit * percentage_from // 100
+                    to_take = iter_limit - from_take
+                    result["limit_from"] += from_take
+                    result["limit_to"] += to_take
+                    iter_result = {
+                        "limit": iter_limit,
+                        "from": percentage_from,
+                        "to": percentage_to,
+                        "from_take": from_take,
+                        "to_take": to_take,
+                    }
+                    result["percentages"].append(iter_result)
+
+            if first_iter:
+                first_iter = False
+
+        return result
+
+    def build_plan(
+        self,
+        *,
+        ctx: ExecutionContext,
+        limit: int,
+        next_page: FeedResultNextPage,
+        **params: Any,
+    ) -> SlotsPlan:
+        start_page = next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1
+        start_after = next_page.data[self.merger_id].after if self.merger_id in next_page.data else None
+
+        plan_next_page = FeedResultNextPage(
+            data={
+                **next_page.data,
+                self.merger_id: FeedResultNextPageInside(page=start_page, after=start_after),
+            }
+        )
+
+        limits_and_percents = self._calculate_limits_and_percents(page=start_page, limit=limit)
+
+        owner_from = cast(BaseFeedConfigModel, self.item_from.data)
+        owner_to = cast(BaseFeedConfigModel, self.item_to.data)
+
+        slots = [
+            SlotSpec(owner=owner_from, max_count=int(limits_and_percents["limit_from"])),
+            SlotSpec(owner=owner_to, max_count=int(limits_and_percents["limit_to"])),
+        ]
+
+        def _assemble(
+            output: List[Any],
+            merged_next_page: FeedResultNextPage,
+            owner_results: Dict[int, FeedResult],
+        ) -> FeedResult:
+            from_res = owner_results.get(id(owner_from))
+            to_res = owner_results.get(id(owner_to))
+
+            from_data = list(from_res.data) if from_res is not None else []
+            to_data = list(to_res.data) if to_res is not None else []
+
+            data: List[Any] = []
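+            # Replay the precomputed per-step schedule: slice "from_take"
+            # items from the gradient stream and "to_take" items from the
+            # target stream at each step, preserving the percentage curve.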
from_start_index = 0 + to_start_index = 0 + for lp_data in limits_and_percents["percentages"]: + from_take = int(lp_data.get("from_take", lp_data["limit"] * lp_data["from"] // 100)) + to_take = int(lp_data.get("to_take", lp_data["limit"] - from_take)) + + from_end_index = from_start_index + from_take + to_end_index = to_start_index + to_take + + data.extend(from_data[from_start_index:from_end_index]) + data.extend(to_data[to_start_index:to_end_index]) + + from_start_index = from_end_index + to_start_index = to_end_index + + has_next_page = False + if from_res is not None and from_res.has_next_page: + has_next_page = True + if to_res is not None and to_res.has_next_page: + has_next_page = True + + if self.shuffle: + shuffle(data) + + if self.merger_id in merged_next_page.data: + merged_next_page.data[self.merger_id].page += 1 + + return FeedResult(data=data, next_page=merged_next_page, has_next_page=has_next_page) + + return SlotsPlan( + ctx=ctx, + limit=limit, + next_page=plan_next_page, + params=dict(params), + slots=slots, + assemble=_assemble, + ) diff --git a/smartfeed/mergers/positional.py b/smartfeed/mergers/positional.py new file mode 100644 index 0000000..6023aae --- /dev/null +++ b/smartfeed/mergers/positional.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional + +from pydantic import model_validator + +from ..execution.context import ExecutionContext +from ..execution.executor import SlotSpec, SlotsPlan +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage, FeedResultNextPageInside + +if TYPE_CHECKING: + from ..schemas import FeedTypes + + +class MergerPositional(BaseFeedConfigModel): + """Positional merger.""" + + merger_id: str + type: Literal["merger_positional"] + positions: List[int] = [] + start: Optional[int] = None + end: Optional[int] = None + step: Optional[int] = None + positional: FeedTypes + default: FeedTypes + + @model_validator(mode="after") + def validate_merger_positional(self) -> "MergerPositional": + if not self.positions and not all((self.start, self.end, self.step)): + raise ValueError('Either "positions" or "start", "end", and "step" must be provided') + if self.start and self.positions: + if isinstance(self.start, int) and self.start <= max(self.positions): + raise ValueError('"start" must be bigger than maximum value of "positions"') + if isinstance(self.start, int) and isinstance(self.end, int): + if self.end <= self.start: + raise ValueError('"end" must be bigger than "start"') + return self + + def build_plan( + self, + *, + ctx: ExecutionContext, + limit: int, + next_page: FeedResultNextPage, + **params: Any, + ) -> SlotsPlan: + page = next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1 + + positional_has_next_page = True + page_positions: List[int] = [] + available_positions = range((page - 1) * limit, (page * limit) + 1) + for position in self.positions: + if position in available_positions: + page_positions.append(available_positions.index(position)) + + if max(available_positions) >= max(self.positions, default=0): + positional_has_next_page = False + + if self.start is not None and self.end is not None and self.step is not None: + positional_has_next_page = not max(available_positions) >= self.end + + for position in range(self.start, self.end, self.step): + if position in available_positions: + page_positions.append(available_positions.index(position)) + + pos_limit = len(page_positions) + + # Build a slot ownership 
schedule by applying the same sequential insert + # semantics as the legacy assembly logic. + schedule: List[BaseFeedConfigModel] = [self.default for _ in range(limit)] + for insert_index in [p - 1 for p in page_positions[:pos_limit]]: + schedule.insert(insert_index, self.positional) + schedule = schedule[:limit] + + # Compress the schedule into contiguous segments. + slots: List[SlotSpec] = [] + if schedule: + current_owner = schedule[0] + count = 1 + for owner in schedule[1:]: + if owner is current_owner: + count += 1 + continue + slots.append(SlotSpec(owner=current_owner, max_count=count)) + current_owner = owner + count = 1 + slots.append(SlotSpec(owner=current_owner, max_count=count)) + + after = next_page.data[self.merger_id].after if self.merger_id in next_page.data else None + + def _assemble( + output: List[Any], merged_next_page: FeedResultNextPage, owner_results: Dict[int, FeedResult] + ) -> FeedResult: + default_res = owner_results.get(id(self.default)) + pos_res = owner_results.get(id(self.positional)) + + has_next_page = bool(default_res.has_next_page) if default_res is not None else False + if not has_next_page and positional_has_next_page and pos_res is not None and pos_res.has_next_page: + has_next_page = True + + result_next_page = merged_next_page + result_next_page.data[self.merger_id] = FeedResultNextPageInside(page=page + 1, after=after) + return FeedResult(data=output, next_page=result_next_page, has_next_page=has_next_page) + + return SlotsPlan( + ctx=ctx, + limit=limit, + next_page=next_page, + params=dict(params), + slots=slots, + owner_fetch_limits={ + id(self.default): limit, + id(self.positional): pos_limit, + }, + assemble=_assemble, + ) diff --git a/smartfeed/mergers/view_session.py b/smartfeed/mergers/view_session.py new file mode 100644 index 0000000..dad3314 --- /dev/null +++ b/smartfeed/mergers/view_session.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import logging +from random import shuffle +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + +import redis +from redis.asyncio import Redis as AsyncRedis + +from .. 
import jsonlib as json +from ..execution.context import ExecutionContext +from ..execution.executor import CallablePlan +from ..feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage, FeedResultNextPageInside, _redis_call +from ..policies.dedup import entity_key + +if TYPE_CHECKING: + from ..schemas import FeedTypes + + +class MergerViewSession(BaseFeedConfigModel): + """Merger with view-session caching.""" + + merger_id: str + type: Literal["merger_view_session"] + session_size: int + session_live_time: int + data: "FeedTypes" + deduplicate: bool = False + dedup_key: Optional[str] = None + missing_key_policy: Literal["error", "keep", "drop"] = "error" + shuffle: bool = False + + def _get_dedup_key_or_attr(self, item: Any) -> str: + key = entity_key(item, self.dedup_key, self.missing_key_policy) + assert key is not None, "Deduplication key is missing and item was dropped by missing_key_policy='drop'" + return key + + def _dedup_data(self, data: List[Any]) -> List[Any]: + deduplicated: List[Any] = [] + seen: set[str] = set() + for item in data: + key = entity_key(item, self.dedup_key, self.missing_key_policy) + if key is None: + continue + if key in seen: + continue + seen.add(key) + deduplicated.append(item) + return deduplicated + + async def _set_cache( + self, + redis_client: Union[redis.Redis, AsyncRedis], + cache_key: str, + ctx: ExecutionContext, + **params: Any, + ) -> List[Any]: + if ctx.executor is None: + raise ValueError("Executor must be initialized for MergerViewSession") + + result = await ctx.executor.run(self.data, ctx, self.session_size, FeedResultNextPage(data={}), **params) + + data = result.data + if self.deduplicate: + data = self._dedup_data(data) + await _redis_call(redis_client, "set", cache_key, json.dumps(data), ex=self.session_live_time) + return data + + async def _get_cache( + self, + limit: int, + next_page: FeedResultNextPage, + redis_client: Union[redis.Redis, AsyncRedis], + ctx: ExecutionContext, + **params: Any, + ) -> FeedResult: + cache_key = ( + f"{self.merger_id}_{ctx.user_id}_{session_cache_key}" + if (session_cache_key := params.get("custom_view_session_key")) + else f"{self.merger_id}_{ctx.user_id}" + ) + + logging.info("MergerViewSession cache request for %s", cache_key) + cache_exists = bool(await _redis_call(redis_client, "exists", cache_key)) + if not cache_exists or self.merger_id not in next_page.data: + logging.info("Cache miss or new session - generating fresh data for %s", cache_key) + session_data = await self._set_cache( + redis_client=redis_client, + cache_key=cache_key, + ctx=ctx, + **params, + ) + else: + logging.info("Cache exists - attempting read from Redis for %s", cache_key) + cached_data = await _redis_call(redis_client, "get", cache_key) + if cached_data is None: + logging.info( + "Redis returned None for %s - falling back to fresh data (cluster replication issue)", cache_key + ) + session_data = await self._set_cache( + redis_client=redis_client, + cache_key=cache_key, + ctx=ctx, + **params, + ) + else: + logging.info("Successfully read cached data for %s", cache_key) + session_data = json.loads(cached_data) + + page = next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1 + return FeedResult( + data=session_data[(page - 1) * limit :][:limit], + next_page=FeedResultNextPage(data={self.merger_id: FeedResultNextPageInside(page=page + 1, after=None)}), + has_next_page=bool(len(session_data) > limit * page), + ) + + def build_plan( + self, + *, + ctx: ExecutionContext, + limit: int, + 
next_page: FeedResultNextPage, + **params: Any, + ) -> CallablePlan: + async def _run(executor: Any) -> FeedResult: + if ctx.redis_client is None: + raise ValueError("Redis client must be provided if using Merger View Session") + + if ctx.executor is None: + ctx.executor = executor + + result = await self._get_cache( + limit=limit, + next_page=next_page, + redis_client=ctx.redis_client, + ctx=ctx, + **params, + ) + + if self.shuffle: + shuffle(result.data) + return result + + return CallablePlan(fn=_run) diff --git a/smartfeed/policies/dedup.py b/smartfeed/policies/dedup.py new file mode 100644 index 0000000..b040c47 --- /dev/null +++ b/smartfeed/policies/dedup.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple + +from .. import jsonlib as json +from .seen_store import SeenStore + +MissingKeyPolicy = Literal["error", "keep", "drop"] + + +def normalize_key(value: Any) -> str: + if isinstance(value, (str, int)): + return str(value) + if isinstance(value, (dict, list)): + return json.dumps(value, sort_keys=True, default=str) + return str(value) + + +def extract_dedup_value(item: Any, dedup_key: Optional[str], missing_key_policy: MissingKeyPolicy) -> Any: + if not dedup_key: + return item + + try: + value = item.get(dedup_key) + except AttributeError: + value = getattr(item, dedup_key, None) + + if value is None and missing_key_policy == "error": + raise AssertionError(f"Deduplication failed: entity {item} has no key or attr {dedup_key}") + return value + + +def entity_key(item: Any, dedup_key: Optional[str], missing_key_policy: MissingKeyPolicy) -> Optional[str]: + raw_value = extract_dedup_value(item, dedup_key, missing_key_policy) + if raw_value is None: + if missing_key_policy == "drop": + return None + if missing_key_policy == "keep": + raw_value = ("__missing__", id(item)) + return normalize_key(raw_value) + + +@dataclass +class DeduplicationPolicy: + """Deduplication policy applied during execution. + + This keeps dedup logic out of merger implementations and plan interpreters. + """ + + dedup_key: Optional[str] + missing_key_policy: MissingKeyPolicy + + # Store backend (cursor or redis) + store: SeenStore + + # Keys encountered/accepted within this request. Prevents duplicates inside one response. + seen_request_set: set[str] + + def key_for(self, item: Any) -> Optional[str]: + return entity_key(item, self.dedup_key, self.missing_key_policy) + + async def prefetch_keys(self, keys: List[str]) -> None: + if not keys: + return + + filtered: List[str] = [] + seen: set[str] = set() + for k in keys: + if k in self.seen_request_set: + continue + if k in seen: + continue + seen.add(k) + filtered.append(k) + + if not filtered: + return + + await self.store.prefetch(filtered) + + def should_accept(self, key: str, priority: int) -> bool: + if key in self.seen_request_set: + return False + + existing_priority = self.store.get(key) + if existing_priority is not None and priority <= existing_priority: + return False + return True + + def record(self, key: str, priority: int) -> None: + self.seen_request_set.add(key) + self.store.set_max(key, priority) + + async def arbitrate_owner_buffers( + self, + *, + owners: List[Any], + owner_buffers: Dict[int, List[Any]], + owner_rank: Dict[int, int], + ) -> Dict[int, List[Any]]: + """Arbitrate winners across multiple owners. 
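+
+        Illustrative example (hypothetical values): if owners A (priority 10,
+        rank 0) and B (priority 3, rank 1) both offer key "k1", A wins
+        because its tuple (-10, 0, ...) sorts before B's (-3, 1, ...).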
+ + - Deterministic tie-break: (-priority, owner_rank, item_rank) + - Records winners into the store + - Returns per-owner buffers containing only accepted winners + """ + + keys_to_prefetch: List[str] = [] + keys_seen_local: set[str] = set() + for owner in owners: + owner_id = id(owner) + for item in owner_buffers.get(owner_id, []): + key = self.key_for(item) + if key is None: + continue + if key in keys_seen_local: + continue + keys_seen_local.add(key) + keys_to_prefetch.append(key) + + if keys_to_prefetch: + await self.prefetch_keys(keys_to_prefetch) + + winners: Dict[str, int] = {} + winner_prio: Dict[str, int] = {} + winner_tie: Dict[str, Tuple[int, int, int]] = {} + + for owner in owners: + owner_id = id(owner) + prio = int(getattr(owner, "dedup_priority", 0)) + rank = int(owner_rank.get(owner_id, 0)) + for item_rank, item in enumerate(owner_buffers.get(owner_id, [])): + key = self.key_for(item) + if key is None: + continue + tie = (-prio, rank, item_rank) + existing = winner_tie.get(key) + if existing is None or tie < existing: + winners[key] = owner_id + winner_prio[key] = prio + winner_tie[key] = tie + + for key, _tie in sorted(winner_tie.items(), key=lambda kv: kv[1]): + winner_owner_id = winners.get(key) + if winner_owner_id is None: + continue + prio = int(winner_prio.get(key, 0)) + if not self.should_accept(key, prio): + continue + self.record(key, prio) + + request_set = self.seen_request_set + per_owner_accepted: Dict[int, List[Any]] = {id(o): [] for o in owners} + for owner in owners: + owner_id = id(owner) + accepted: List[Any] = [] + for item in owner_buffers.get(owner_id, []): + key = self.key_for(item) + if key is None: + continue + if winners.get(key) != owner_id: + continue + if key not in request_set: + continue + accepted.append(item) + per_owner_accepted[owner_id] = accepted + + return per_owner_accepted + + async def accept_batch( + self, + *, + items: List[Any], + priority: int, + limit: Optional[int] = None, + ) -> Tuple[List[Any], int]: + """Accept items from a single stream in order. + + Returns accepted items and the number of inspected items. + """ + + if not items: + return [], 0 + + keys_to_prefetch: List[str] = [] + keys_seen_local: set[str] = set() + for item in items: + key = self.key_for(item) + if key is None: + continue + if key in self.seen_request_set: + continue + if key in keys_seen_local: + continue + keys_seen_local.add(key) + keys_to_prefetch.append(key) + + if keys_to_prefetch: + await self.prefetch_keys(keys_to_prefetch) + + accepted: List[Any] = [] + inspected_count = 0 + max_accept = int(limit) if limit is not None else len(items) + + for idx, item in enumerate(items, start=1): + inspected_count = idx + if len(accepted) >= max_accept: + break + key = self.key_for(item) + if key is None: + continue + if not self.should_accept(key, priority): + continue + accepted.append(item) + self.record(key, priority) + if len(accepted) >= max_accept: + break + + return accepted, inspected_count diff --git a/smartfeed/policies/dedup_utils.py b/smartfeed/policies/dedup_utils.py new file mode 100644 index 0000000..4483619 --- /dev/null +++ b/smartfeed/policies/dedup_utils.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio +import base64 +import zlib +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union, cast + +import redis +from redis.asyncio import Redis as AsyncRedis + +from .. 
import jsonlib as json +from ..feed_models import _is_async_redis_client, _redis_call + + +def _seen_entries_to_map(entries: Any) -> Dict[str, int]: + """Coerce a legacy cursor "seen" list into a {key: priority} map. + + Supports: + - ["k1", "k2", ...] (implies priority 0) + - [["k1", 10], ["k2", 3], ...] (explicit priorities) + """ + + seen_map: Dict[str, int] = {} + if not isinstance(entries, list): + return seen_map + + for entry_item in entries: + if isinstance(entry_item, (list, tuple)) and len(entry_item) == 2: + seen_map[str(entry_item[0])] = int(entry_item[1]) + else: + seen_map[str(entry_item)] = 0 + return seen_map + + +def decode_seen_from_cursor(after: Any) -> Dict[str, int]: + if after is None: + return {} + + if isinstance(after, dict) and "z" in after: + if after.get("v") != 2: + return {} + if after.get("c") != "zlib+base64": + return {} + payload = base64.urlsafe_b64decode(str(after["z"]).encode()) + raw = zlib.decompress(payload).decode() + decoded = json.loads(raw) + if isinstance(decoded, list): + return _seen_entries_to_map(decoded) + return {} + + if isinstance(after, dict) and "seen" in after: + if after.get("v") != 2: + return {} + return _seen_entries_to_map(list(after["seen"])) + + return {} + + +def encode_seen_for_cursor( + seen_updates_in_order: List[Tuple[str, int]], + *, + cursor_compress: bool, + cursor_max_keys: Optional[int], +) -> Any: + if cursor_max_keys is not None: + seen_updates_in_order = seen_updates_in_order[-cursor_max_keys:] + + if not cursor_compress: + return {"v": 2, "seen": [[k, p] for k, p in seen_updates_in_order]} + + raw = json.dumps([[k, p] for k, p in seen_updates_in_order]).encode() + compressed = zlib.compress(raw) + return { + "v": 2, + "c": "zlib+base64", + "n": len(seen_updates_in_order), + "z": base64.urlsafe_b64encode(compressed).decode(), + } + + +async def redis_zmscore( + redis_client: Union[redis.Redis, AsyncRedis], + key: str, + members: List[str], +) -> List[Optional[float]]: + if not members: + return [] + + if getattr(redis_client, "zmscore", None) is not None: + res = await _redis_call(redis_client, "zmscore", key, members) + return [None if v is None else float(v) for v in list(res)] + + if not _is_async_redis_client(redis_client): + + def _sync_pipeline_execute() -> Any: + pipe = redis_client.pipeline() + for m in members: + pipe.zscore(key, m) + return pipe.execute() + + res = await asyncio.to_thread(_sync_pipeline_execute) + return [None if v is None else float(v) for v in list(res)] + + pipe = redis_client.pipeline() + for m in members: + pipe.zscore(key, m) + res = await cast(Awaitable[Any], pipe.execute()) + return [None if v is None else float(v) for v in list(res)] + + +async def redis_zadd_and_expire( + redis_client: Union[redis.Redis, AsyncRedis], + key: str, + member_scores: Dict[str, int], + *, + ttl_seconds: int, +) -> None: + if not member_scores: + return + await _redis_call(redis_client, "zadd", key, mapping={m: float(s) for m, s in member_scores.items()}) + await _redis_call(redis_client, "expire", key, ttl_seconds) diff --git a/smartfeed/policies/seen_store.py b/smartfeed/policies/seen_store.py new file mode 100644 index 0000000..c4712c3 --- /dev/null +++ b/smartfeed/policies/seen_store.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union + +import redis +from redis.asyncio import Redis as AsyncRedis + +from ..feed_models import _redis_call +from .dedup_utils import decode_seen_from_cursor, 
encode_seen_for_cursor, redis_zadd_and_expire, redis_zmscore + + +class SeenStore(Protocol): + async def prefetch(self, keys: List[str]) -> None: + raise NotImplementedError + + def get(self, key: str) -> Optional[int]: + raise NotImplementedError + + def set_max(self, key: str, priority: int) -> None: + raise NotImplementedError + + async def reset(self) -> None: + raise NotImplementedError + + async def commit(self) -> Any: + raise NotImplementedError + + +@dataclass +class CursorSeenStore: + """Seen-store that persists in the cursor (`after`).""" + + cursor_compress: bool + cursor_max_keys: Optional[int] + + seen_priority_map: Dict[str, int] + seen_order: List[str] + + @classmethod + def from_after( + cls, + after: Any, + *, + cursor_compress: bool, + cursor_max_keys: Optional[int], + ) -> "CursorSeenStore": + seen_priority_map = decode_seen_from_cursor(after) + return cls( + cursor_compress=cursor_compress, + cursor_max_keys=cursor_max_keys, + seen_priority_map=seen_priority_map, + seen_order=list(seen_priority_map.keys()), + ) + + async def prefetch(self, keys: List[str]) -> None: + return None + + def get(self, key: str) -> Optional[int]: + return self.seen_priority_map.get(key) + + def set_max(self, key: str, priority: int) -> None: + existing = self.seen_priority_map.get(key) + if existing is not None and priority <= existing: + return + self.seen_priority_map[key] = priority + if key in self.seen_order: + self.seen_order.remove(key) + self.seen_order.append(key) + + async def reset(self) -> None: + self.seen_priority_map.clear() + self.seen_order.clear() + + async def commit(self) -> Any: + # Persist the full snapshot so dedup state survives beyond 2 pages. + seen_snapshot_in_order: List[Tuple[str, int]] = [ + (key, self.seen_priority_map[key]) for key in self.seen_order if key in self.seen_priority_map + ] + return encode_seen_for_cursor( + seen_snapshot_in_order, + cursor_compress=self.cursor_compress, + cursor_max_keys=self.cursor_max_keys, + ) + + +@dataclass +class RedisSeenStore: + """Seen-store backed by a Redis zset (member=key, score=priority).""" + + redis_client: Union[redis.Redis, AsyncRedis] + redis_key: str + ttl_seconds: int + + redis_seen_cache: Dict[str, Optional[int]] + redis_new_scores: Dict[str, int] + + @classmethod + def create( + cls, + *, + redis_client: Union[redis.Redis, AsyncRedis], + redis_key: str, + ttl_seconds: int, + ) -> "RedisSeenStore": + return cls( + redis_client=redis_client, + redis_key=redis_key, + ttl_seconds=ttl_seconds, + redis_seen_cache={}, + redis_new_scores={}, + ) + + async def prefetch(self, keys: List[str]) -> None: + if not keys: + return + + unique: List[str] = [] + seen: set[str] = set() + for k in keys: + if k in self.redis_seen_cache: + continue + if k in seen: + continue + seen.add(k) + unique.append(k) + + if not unique: + return + + scores = await redis_zmscore(self.redis_client, self.redis_key, unique) + + for k, s in zip(unique, scores): + self.redis_seen_cache[k] = None if s is None else int(s) + + def get(self, key: str) -> Optional[int]: + return self.redis_seen_cache.get(key) + + def set_max(self, key: str, priority: int) -> None: + existing = self.redis_seen_cache.get(key) + if existing is not None and priority <= existing: + return + self.redis_seen_cache[key] = priority + self.redis_new_scores[key] = max(self.redis_new_scores.get(key, 0), priority) + + async def reset(self) -> None: + await _redis_call(self.redis_client, "delete", self.redis_key) + self.redis_seen_cache.clear() + self.redis_new_scores.clear() + + 
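+    # commit() below sends only the scores recorded during this request
+    # (redis_new_scores) and refreshes the zset TTL; members already known
+    # server-side are re-sent only if their priority was raised.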
async def commit(self) -> Any: + await redis_zadd_and_expire( + self.redis_client, + self.redis_key, + self.redis_new_scores, + ttl_seconds=self.ttl_seconds, + ) + self.redis_new_scores.clear() + return None diff --git a/smartfeed/pydantic_compat.py b/smartfeed/pydantic_compat.py new file mode 100644 index 0000000..185e004 --- /dev/null +++ b/smartfeed/pydantic_compat.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Any, Mapping, Type, TypeVar + +T = TypeVar("T") + + +def parse_model(model_cls: Type[T], obj: Mapping[str, Any]) -> T: + """Parse a mapping into a Pydantic model. + + Uses Pydantic v2 `model_validate` when available, otherwise falls back to v1 `parse_obj`. + """ + + if hasattr(model_cls, "model_validate"): + return model_cls.model_validate(obj) # type: ignore[attr-defined] + return model_cls.parse_obj(obj) # type: ignore[attr-defined] diff --git a/smartfeed/schemas.py b/smartfeed/schemas.py index 45df221..9863bb4 100644 --- a/smartfeed/schemas.py +++ b/smartfeed/schemas.py @@ -1,1132 +1,95 @@ -import inspect -import json -import logging -from abc import ABC, abstractmethod -from collections import defaultdict, deque -from random import shuffle -from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Union, no_type_check - -import redis -from pydantic import BaseModel, Field, root_validator -from redis.asyncio import Redis as AsyncRedis -from redis.asyncio import RedisCluster as AsyncRedisCluster +"""Public schema surface. + +This module keeps the public import path (`smartfeed.schemas`) stable while +moving merger implementations into `smartfeed.mergers.*`. +""" + +from __future__ import annotations + +from typing import Annotated, Any, Dict, Union + +from pydantic import BaseModel, Field + +from .feed_models import ( + BaseFeedConfigModel, + FeedResult, + FeedResultClient, + FeedResultNextPage, + FeedResultNextPageInside, + SubFeed, +) +from .mergers import ( + MergerAppend, + MergerAppendDistribute, + MergerDeduplication, + MergerPercentage, + MergerPercentageGradient, + MergerPercentageItem, + MergerPositional, + MergerViewSession, +) FeedTypes = Annotated[ Union[ - "MergerAppend", - "MergerAppendDistribute", - "MergerPositional", - "MergerPercentage", - "MergerPercentageGradient", - "MergerViewSession", - "SubFeed", + MergerDeduplication, + MergerAppend, + MergerAppendDistribute, + MergerPositional, + MergerPercentage, + MergerPercentageGradient, + MergerViewSession, + SubFeed, ], Field(discriminator="type"), ] -class FeedResultNextPageInside(BaseModel): - """ - Модель данных курсора пагинации конкретной позиции. - - Attributes: - page порядковый номер страницы. - after данные для пагинации клиентского метода. - """ - - page: int = 1 - after: Any = None - - -class FeedResultNextPage(BaseModel): - """ - Модель курсора пагинации. - - Attributes: - data словарь вида "ключ: данные по пагинации", где ключ - subfeed_id или merger_id. - """ - - data: Dict[str, FeedResultNextPageInside] - - -class FeedResult(BaseModel): - """ - Модель результата метода get_data() любой позиции / целого фида. - - Attributes: - data список данных, возвращенных мерджером / субфидом. - next_page курсор пагинации. - has_next_page флаг наличия следующей страницы данных. - """ - - data: List - next_page: FeedResultNextPage - has_next_page: bool - - -class FeedResultClient(BaseModel): - """ - Модель результата клиентского метода субфида. - - Attributes: - data список данных, возвращенных мерджером / субфидом. - next_page курсор пагинации клиентского метода. 
- has_next_page флаг наличия следующей страницы данных. - """ - - data: List - next_page: FeedResultNextPageInside - has_next_page: bool - - -class BaseFeedConfigModel(ABC, BaseModel): - """ - Абстрактный класс для мерджера / субфида конфигурации. - """ - - @abstractmethod - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: параметры для метода. - :return: список данных. - """ - - -class MergerViewSession(BaseFeedConfigModel): - """ - Модель мерджера с кэшированием. - - Attributes: - merger_id уникальный ID мерджера. - type тип объекта - всегда "merger_view_session". - view_session флаг использования механизма расчета всего фида сразу и сохранения в кэш. - session_size размер кэшируемого фида (limit получения данных для сохранения в кэш). - session_live_time срок хранения в кэше для кэшируемого фида (в секундах). - data мерджер или субфид. - deduplicate флаг дедупликации (удаления дублей из сессии). - dedup_key название ключа или атрибута, по которому логика дедпликации найдет дубли. - shuffle флаг для перемешивания полученных данных мерджера. - """ - - merger_id: str - type: Literal["merger_view_session"] - session_size: int - session_live_time: int - data: FeedTypes - deduplicate: bool = False - dedup_key: str = None # type: ignore - shuffle: bool = False - - def _get_dedup_key_or_attr(self, item: Any) -> str: - """ - Метод для получения ключа объекта кешируемой сессии. - - Если указанное в конфиге сессии название ключа имеет значение None, - в качестве ключа вернется сам объект. - Если название ключа не None, и для одного из объектов ни найден ни ключ, ни атрибут, - метод выбросит AssertionError. - - :param item: объект, для которого нужен ключ. - :return: ключ объекта. - """ - - if not self.dedup_key: - return item - - try: - dedup_value = item.get(self.dedup_key) - except AttributeError: - dedup_value = getattr(item, self.dedup_key, None) - - assert dedup_value is not None, f"Deduplication failed: entity {item} has no key or attr {self.dedup_key}" - return dedup_value - - def _dedup_data(self, data: List[Any]) -> List[Any]: - """ - Метод для удаления дублей в списке data с сохранением последовательности. - - :param data: список, в котором нужно удалить дубли. - :return: результат удаления дублей. - """ - - deduplicated_data = {self._get_dedup_key_or_attr(item): item for item in data} - return list(deduplicated_data.values()) - - async def _set_cache( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - redis_client: redis.Redis, - cache_key: str, - **params: Any, - ) -> List[Any]: - """ - Метод для кэширования данных Merger View Session. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param redis_client: объект клиента Redis. - :param cache_key: ключ для кэширования. - :param params: любые внешние параметры, передаваемые в исполняемую функцию на клиентской стороне. - :return: обработанные данные, которые были записаны в кэш. 
- """ - - result = await self.data.get_data( - methods_dict=methods_dict, - user_id=user_id, - limit=self.session_size, - next_page=FeedResultNextPage(data={}), - **params, - ) - - data = result.data - if self.deduplicate: - data = self._dedup_data(data) - redis_client.set(name=cache_key, value=json.dumps(data), ex=self.session_live_time) - return data - - async def _set_cache_async( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - redis_client: AsyncRedis, - cache_key: str, - **params: Any, - ) -> List[Any]: - """ - Метод для кэширования данных Merger View Session. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param redis_client: объект клиента Redis. - :param cache_key: ключ для кэширования. - :param params: любые внешние параметры, передаваемые в исполняемую функцию на клиентской стороне. - :return: обработанные данные, которые были записаны в кэш. - """ - - result = await self.data.get_data( - methods_dict=methods_dict, - user_id=user_id, - limit=self.session_size, - next_page=FeedResultNextPage(data={}), - **params, - ) - - data = result.data - if self.deduplicate: - data = self._dedup_data(data) - await redis_client.set(cache_key, json.dumps(data)) - await redis_client.expire(cache_key, self.session_live_time) - return data - - async def _get_cache( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: redis.Redis, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных Merger View Session из кэша Redis. - При отсутствии данных в кэше - получить и сохранить. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: лимит на выдачу данных. - :param next_page: курсор для пагинации в формате SmartFeedResultNextPage. - :param redis_client: объект клиента Redis. - :param params: любые внешние параметры, передаваемые в исполняемую функцию на клиентской стороне. - :return: результат получения данных согласно конфигурации фида. - """ - - # Формируем ключ для кэширования данных мерджера. - if session_cache_key := params.get("custom_view_session_key", None): - cache_key = f"{self.merger_id}_{user_id}_{session_cache_key}" - else: - cache_key = f"{self.merger_id}_{user_id}" - - logging.info("MergerViewSession cache request for %s", cache_key) - # Если кэш не найден или передан пустой курсор пагинации на мерджер, обновляем данные и записываем в кэш. 
- if not redis_client.exists(cache_key) or self.merger_id not in next_page.data: - logging.info("Cache miss or new session - generating fresh data for %s", cache_key) - # Получаем свежие данные и используем их напрямую (избегаем чтение из кэша) - session_data = await self._set_cache( - methods_dict=methods_dict, user_id=user_id, redis_client=redis_client, cache_key=cache_key, **params - ) - else: - logging.info("Cache exists - attempting read from Redis for %s", cache_key) - # Читаем из кэша только если он уже существовал - cached_data = redis_client.get(name=cache_key) - if cached_data is None: - # Fallback: если кэш пропал, получаем свежие данные - logging.info( - "Redis returned None for %s - falling back to fresh data (cluster replication issue)", cache_key - ) - session_data = await self._set_cache( - methods_dict=methods_dict, user_id=user_id, redis_client=redis_client, cache_key=cache_key, **params - ) - else: - logging.info("Successfully read cached data for %s", cache_key) - session_data = json.loads(cached_data) - page = next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1 - result = FeedResult( - data=session_data[(page - 1) * limit :][:limit], - next_page=FeedResultNextPage(data={self.merger_id: FeedResultNextPageInside(page=page + 1, after=None)}), - has_next_page=bool(len(session_data) > limit * page), - ) - return result - - async def _get_cache_async( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: AsyncRedis, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных Merger View Session из кэша Redis. - При отсутствии данных в кэше - получить и сохранить. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: лимит на выдачу данных. - :param next_page: курсор для пагинации в формате SmartFeedResultNextPage. - :param redis_client: объект клиента Redis. - :param params: любые внешние параметры, передаваемые в исполняемую функцию на клиентской стороне. - :return: результат получения данных согласно конфигурации фида. - """ - - # Формируем ключ для кэширования данных мерджера. - if session_cache_key := params.get("custom_view_session_key", None): - cache_key = f"{self.merger_id}_{user_id}_{session_cache_key}" - else: - cache_key = f"{self.merger_id}_{user_id}" - - # Если кэш не найден или передан пустой курсор пагинации на мерджер, обновляем данные и записываем в кэш. 
- if not await redis_client.exists(cache_key) or self.merger_id not in next_page.data: - # Получаем свежие данные и используем их напрямую (избегаем чтение из кэша) - session_data = await self._set_cache_async( - methods_dict=methods_dict, user_id=user_id, redis_client=redis_client, cache_key=cache_key, **params - ) - else: - # Читаем из кэша только если он уже существовал - cached_data = await redis_client.get(cache_key) - if cached_data is None: - # Fallback: если кэш пропал, получаем свежие данные - logging.info( - "Redis returned None for %s - falling back to fresh data (cluster replication issue)", cache_key - ) - session_data = await self._set_cache_async( - methods_dict=methods_dict, user_id=user_id, redis_client=redis_client, cache_key=cache_key, **params - ) - else: - logging.info("Successfully read cached data for %s", cache_key) - session_data = json.loads(cached_data) - page = next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1 - result = FeedResult( - data=session_data[(page - 1) * limit :][:limit], - next_page=FeedResultNextPage(data={self.merger_id: FeedResultNextPageInside(page=page + 1, after=None)}), - has_next_page=bool(len(session_data) > limit * page), - ) - return result - - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных методом append. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: для метода класса. - :return: список данных методом append. - """ - - # Проверяем наличие клиента Redis в конфигурации фида. - if not redis_client: - raise ValueError("Redis client must be provided if using Merger View Session") - - # Формируем результат view session мерджера. - if isinstance(redis_client, (AsyncRedis, AsyncRedisCluster)): - result = await self._get_cache_async( - methods_dict=methods_dict, - user_id=user_id, - limit=limit, - next_page=next_page, - redis_client=redis_client, - **params, - ) - else: - result = await self._get_cache( - methods_dict=methods_dict, - user_id=user_id, - limit=limit, - next_page=next_page, - redis_client=redis_client, - **params, - ) - - # Если в конфигурации указано "смешать" данные. - if self.shuffle: - shuffle(result.data) - - return result - - -class MergerAppend(BaseFeedConfigModel): - """ - Модель append мерджера. - - Attributes: - merger_id уникальный ID мерджера. - type тип объекта - всегда "merger_append". - items позиции мерджера. - shuffle флаг для перемешивания полученных данных мерджера. - """ - - merger_id: str - type: Literal["merger_append"] - items: List[FeedTypes] - shuffle: bool = False - - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных методом append. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. 
- :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: для метода класса. - :return: список данных методом append. - """ - - # Формируем результат append мерджера. - result = FeedResult(data=[], next_page=FeedResultNextPage(data={}), has_next_page=False) - - result_limit = limit - for item in self.items: - # Получаем данные из позиции мерджера. - item_result = await item.get_data( - methods_dict=methods_dict, - user_id=user_id, - limit=result_limit, - next_page=next_page, - redis_client=redis_client, - **params, - ) - - # Добавляем данные позиции к общему результату процентного мерджера. - result.data.extend(item_result.data) - - # Обновляем result_limit - result_limit -= len(item_result.data) - - # Если has_next_page = False, то проверяем has_next_page у позиции и, если необходимо, обновляем. - if not result.has_next_page and item_result.has_next_page: - result.has_next_page = True - - # Обновляем next_page. - result.next_page.data.update(item_result.next_page.data) - - # Если полученных данных хватает, то прерываем итерацию и возвращаем результат. - if result_limit <= 0: - break - - # Если в конфигурации указано "смешать" данные. - if self.shuffle: - shuffle(result.data) - - return result - - -class MergerPositional(BaseFeedConfigModel): - """ - Модель позиционного мерджера. - - Attributes: - merger_id уникальный ID мерджера. - type тип объекта - всегда "merger_positional". - positions позиции для вставки из мерджера / субфида "positional" [обязателен, если нет start, end, step]. - start начальная позиция [обязателен, если нет positions]. - end завершающая позиция [обязателен, если нет positions]. - step шаг позиций между "start" и "end" [обязателен, если нет positions]. - positional мерджер / субфид из которого берутся позиционные данные. - default мерджер / субфид из которого берутся остальные данные. - """ - - merger_id: str - type: Literal["merger_positional"] - positions: List[int] = [] - start: Optional[int] = None - end: Optional[int] = None - step: Optional[int] = None - positional: FeedTypes - default: FeedTypes - - @root_validator(skip_on_failure=True) - def validate_merger_positional(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if not values["positions"] and not all((values["start"], values["end"], values["step"])): - raise ValueError('Either "positions" or "start", "end", and "step" must be provided') - if values["start"] and values["positions"]: - if isinstance(values["start"], int) and values["start"] <= max(values["positions"]): - raise ValueError('"start" must be bigger than maximum value of "positions"') - if isinstance(values["start"], int) and isinstance(values["end"], int): - if values["end"] <= values["start"]: - raise ValueError('"end" must be bigger than "start"') - return values - - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных в позиционном соотношении из данных позиций. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: для метода класса. - :return: список данных в процентном соотношении. - """ - - # Получаем данные "default". 
- default_res = await self.default.get_data( - methods_dict=methods_dict, - user_id=user_id, - limit=limit, - next_page=next_page, - redis_client=redis_client, - **params, - ) - - # Формируем результат позиционного мерджера. - result = FeedResult( - data=default_res.data, - next_page=FeedResultNextPage( - data={ - self.merger_id: FeedResultNextPageInside( - page=next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1, - after=next_page.data[self.merger_id].after if self.merger_id in next_page.data else None, - ) - }, - ), - has_next_page=default_res.has_next_page, - ) - - # Получаем список позиций с учетом текущей страницы. - positional_has_next_page = True - page_positions = [] - available_positions = range( - (result.next_page.data[self.merger_id].page - 1) * limit, - (result.next_page.data[self.merger_id].page * limit) + 1, - ) - for position in self.positions: - if position in available_positions: - page_positions.append(available_positions.index(position)) - - # Если конечная позиция текущей страницы больше или равна MAX позиции в конфигурации, то has_next_page = False - if max(available_positions) >= max(self.positions, default=0): - positional_has_next_page = False - - if self.start is not None and self.end is not None and self.step is not None: - # Если конечная позиция текущей страницы больше или равна конечной шаговой позиции, то has_next_page = False - positional_has_next_page = not max(available_positions) >= self.end - - for position in range(self.start, self.end, self.step): - if position in available_positions: - page_positions.append(available_positions.index(position)) - - # Получаем данные "positional". - pos_res = await self.positional.get_data( - methods_dict=methods_dict, - user_id=user_id, - limit=len(page_positions), - next_page=next_page, - redis_client=redis_client, - **params, - ) - - # Если has_next_page = False, то проверяем has_next_page у позиции и, если необходимо, обновляем. - if not result.has_next_page and all([positional_has_next_page, pos_res.has_next_page]): - result.has_next_page = True - - # Обновляем next_page. - result.next_page.data.update(default_res.next_page.data) - result.next_page.data.update(pos_res.next_page.data) - - # Формируем общие данные позиционного мерджера. - for i, post in enumerate(pos_res.data): - result.data = result.data[: page_positions[i] - 1] + [post] + result.data[page_positions[i] - 1 :] - - # Проверка на возврат данных в количестве не более limit. - if len(result.data) > limit: - result.data = result.data[:limit] - - # Обновляем страницу для курсора пагинации мерджера. - result.next_page.data[self.merger_id].page += 1 - - return result - - -class MergerPercentageItem(BaseModel): - """ - Модель позиции процентного мерджера. - - Attributes: - percentage процент позиции в мерджере. - data мерджер / субфид. - """ - - percentage: int - data: FeedTypes - - -class MergerPercentage(BaseFeedConfigModel): - """ - Модель процентного мерджера. - - Attributes: - merger_id уникальный ID мерджера. - type тип объекта - всегда "merger_percentage". - shuffle флаг для перемешивания полученных данных мерджера. - items позиции мерджера. - """ - - merger_id: str - type: Literal["merger_percentage"] - items: List[MergerPercentageItem] - shuffle: bool = False - - @staticmethod - async def _merge_items_data(items_data: List[List]) -> List: - """ - Метод для получения максимально равномерно распределенных данных позиций процентного мерджера. - - :param items_data: список со списками данных из каждой позиции. 
-        :return: the positions' data, distributed as evenly as possible.
-        """
-
-        # Prepare the returned result and a cursor for each position's list.
-        result: List = []
-        cursor: List[Dict] = []
-
-        # Take the length of the smallest list and build a cursor for each list.
-        min_length = min(len(item_data) for item_data in items_data) or 1
-        for item_data in items_data:
-            cursor.append(
-                {
-                    "items": item_data,
-                    "current": 0,
-                    "size": round(len(item_data) / min_length),
-                }
-            )
-
-        # Compute the combined size of all lists and keep distributing elements
-        # until the result reaches that size.
-        full_length = sum(len(item_data) for item_data in items_data)
-        while len(result) < full_length:
-            for item_cursor in cursor:
-                items = item_cursor["items"]
-                start = item_cursor["current"]
-                end = start + item_cursor["size"] if start + item_cursor["size"] < len(items) else len(items)
-                result.extend(items[start:end])
-                item_cursor["current"] = end
-
-        return result
-
-    async def get_data(
-        self,
-        methods_dict: Dict[str, Callable],
-        user_id: Any,
-        limit: int,
-        next_page: FeedResultNextPage,
-        redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None,
-        **params: Any,
-    ) -> FeedResult:
-        """
-        Fetch data from the positions in their percentage proportions.
-
-        :param methods_dict: dictionary of client methods in use.
-        :param user_id: ID of the object to fetch data for (e.g. a user ID).
-        :param limit: number of items.
-        :param next_page: pagination cursor.
-        :param redis_client: Redis client instance (for configs that use the view_session merger).
-        :param params: extra parameters for the client method.
-        :return: list of data in percentage proportions.
-        """
-
-        # Build the percentage merger result.
-        result = FeedResult(data=[], next_page=FeedResultNextPage(data={}), has_next_page=False)
-
-        items_data: List = []
-        for item in self.items:
-            # Fetch this percentage merger position's data.
-            item_result = await item.data.get_data(
-                methods_dict=methods_dict,
-                user_id=user_id,
-                limit=limit * item.percentage // 100,
-                next_page=next_page,
-                redis_client=redis_client,
-                **params,
-            )
-
-            # Collect the position's data.
-            items_data.append(item_result.data)
-
-            # If has_next_page is False, check the position's has_next_page and update if needed.
-            if not result.has_next_page and item_result.has_next_page:
-                result.has_next_page = True
-
-            # Update next_page.
-            result.next_page.data.update(item_result.next_page.data)
-
-        # Merge the positions' data into the overall percentage merger result.
-        result.data = await self._merge_items_data(items_data=items_data)
-
-        # Shuffle the data if the config says so.
-        if self.shuffle:
-            shuffle(result.data)
-
-        return result
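To make the even-distribution rule concrete, here is a standalone sketch that mirrors the legacy `_merge_items_data` logic (illustrative names, not a public API):

    def interleave(items_data):
        result, cursors = [], []
        # Chunk size per list = its length relative to the smallest list.
        min_length = min(len(d) for d in items_data) or 1
        for d in items_data:
            cursors.append({"items": d, "current": 0, "size": round(len(d) / min_length)})
        total = sum(len(d) for d in items_data)
        while len(result) < total:
            for c in cursors:
                end = min(c["current"] + c["size"], len(c["items"]))
                result.extend(c["items"][c["current"]:end])
                c["current"] = end
        return result

    # A 6-item list paired with a 3-item list interleaves two-to-one:
    assert interleave([["a1", "a2", "a3", "a4", "a5", "a6"], ["b1", "b2", "b3"]]) == [
        "a1", "a2", "b1", "a3", "a4", "b2", "a5", "a6", "b3",
    ]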
- """ - - merger_id: str - type: Literal["merger_percentage_gradient"] - item_from: MergerPercentageItem - item_to: MergerPercentageItem - step: int - size_to_step: int - shuffle: bool = False - - @root_validator(skip_on_failure=True) - def validate_merger_percentage_gradient(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if values["step"] < 1 or values["step"] > 100: - raise ValueError('"step" must be in range from 1 to 100') - if values["size_to_step"] < 1: - raise ValueError('"size_to_step" must be bigger than 1') - return values - - async def _calculate_limits_and_percents(self, page: int, limit: int) -> Dict: - """ - Метод для получения списка лимитов данных с процентным соотношением позиций item_from & item_to, - учитывая градиентное изменение соотношений. - - :param page: порядковый номер страницы. - :param limit: общий лимит данных для страницы. - :return: список лимитов данных с процентным соотношением позиций item_from & item_to. - """ - - result: Dict = { - "limit_from": 0, - "limit_to": 0, - "percentages": [], - } - - percentage_from = self.item_from.percentage - percentage_to = self.item_to.percentage - start_position = limit * (page - 1) - first_iter = True - - for i in range(self.size_to_step, limit * page + self.size_to_step, self.size_to_step): - # При первой итерации и percentage_to >= 100 не меняем соотношение % между позициями. - if not first_iter and percentage_to < 100: - # Меняем процентное соотношение позиций на "шаг", указанный в конфигурации. - percentage_from -= self.step - percentage_to += self.step - - # Если процентное соотношение вышло за 100+, то устанавливаем предельные значения. - if percentage_to > 100 or percentage_from < 0: - percentage_from = 0 - percentage_to = 100 - - # Если индекс итерации по величине больше стартовой позиции согласно переданной странице, - # то начинаем обработку. - if i > start_position: - # Рассчитываем лимит получения данных для конкретной итерации. - iter_limit = (limit * page - start_position) if i > limit * page else (i - start_position) - start_position = i - - # Формируем результат для каждой итерации и добавляем в возвращаемый список, но если процентное - # соотношение у последней итерации 0 - 100, то добавляем лимит к ней. - if result["percentages"] and result["percentages"][-1]["to"] >= 100: - result["limit_to"] += iter_limit - result["percentages"][-1]["limit"] += iter_limit - else: - result["limit_from"] += iter_limit * percentage_from // 100 - result["limit_to"] += iter_limit * percentage_to // 100 - iter_result = {"limit": iter_limit, "from": percentage_from, "to": percentage_to} - result["percentages"].append(iter_result) - - # Если первая итерация цикла - if first_iter: - first_iter = False - - return result - - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных в процентном соотношении с градиентом из данных позиций. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: для метода класса. - :return: список данных в процентном соотношении. - """ - - # Формируем результат процентного мерджера с градиентом. 
-
-    async def get_data(
-        self,
-        methods_dict: Dict[str, Callable],
-        user_id: Any,
-        limit: int,
-        next_page: FeedResultNextPage,
-        redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None,
-        **params: Any,
-    ) -> FeedResult:
-        """
-        Fetch data from the positions in gradient percentage proportions.
-
-        :param methods_dict: dictionary of client methods in use.
-        :param user_id: ID of the object to fetch data for (e.g. a user ID).
-        :param limit: number of items.
-        :param next_page: pagination cursor.
-        :param redis_client: Redis client instance (for configs that use the view_session merger).
-        :param params: extra parameters for the client method.
-        :return: list of data in gradient percentage proportions.
-        """
-
-        # Build the gradient percentage merger result.
-        result = FeedResult(
-            data=[],
-            next_page=FeedResultNextPage(
-                data={
-                    self.merger_id: FeedResultNextPageInside(
-                        page=next_page.data[self.merger_id].page if self.merger_id in next_page.data else 1,
-                        after=next_page.data[self.merger_id].after if self.merger_id in next_page.data else None,
-                    )
-                },
-            ),
-            has_next_page=False,
-        )
-
-        # Compute the fetch limits and ratios for this page of the gradient.
-        limits_and_percents = await self._calculate_limits_and_percents(
-            page=result.next_page.data[self.merger_id].page,
-            limit=limit,
-        )
-
-        # Fetch the positions' data in the computed proportions.
-        item_from = await self.item_from.data.get_data(
-            methods_dict=methods_dict,
-            user_id=user_id,
-            limit=limits_and_percents["limit_from"],
-            next_page=next_page,
-            redis_client=redis_client,
-            **params,
-        )
-        item_to = await self.item_to.data.get_data(
-            methods_dict=methods_dict,
-            user_id=user_id,
-            limit=limits_and_percents["limit_to"],
-            next_page=next_page,
-            redis_client=redis_client,
-            **params,
-        )
-
-        from_start_index = 0
-        to_start_index = 0
-        for lp_data in limits_and_percents["percentages"]:
-            # Derive each position's slice size from its percentage share.
-            from_end_index = (lp_data["limit"] * lp_data["from"] // 100) + from_start_index
-            to_end_index = (lp_data["limit"] * lp_data["to"] // 100) + to_start_index
-
-            # Append the slices to the overall gradient merger result.
-            result.data.extend(item_from.data[from_start_index:from_end_index])
-            result.data.extend(item_to.data[to_start_index:to_end_index])
-
-            # Advance the slice start indexes.
-            from_start_index = from_end_index
-            to_start_index = to_end_index
-
-        # Update next_page.
-        result.next_page.data.update(item_from.next_page.data)
-        result.next_page.data.update(item_to.next_page.data)
-
-        # If has_next_page is False, check the positions' has_next_page and update if needed.
-        if any([item_from.has_next_page, item_to.has_next_page]):
-            result.has_next_page = True
-
-        # Shuffle the data if the config says so.
-        if self.shuffle:
-            shuffle(result.data)
-
-        # Advance the page in the merger's pagination cursor.
-        result.next_page.data[self.merger_id].page += 1
-
-        return result
-
-
-class MergerAppendDistribute(BaseFeedConfigModel):
-    """
-    Model of a merger that distributes its data evenly by a key.
-
-    Attributes:
-        merger_id          unique merger ID.
-        type               object type - always "merger_distribute".
-        items              merger positions.
-        distribution_key   key to distribute the merger's data by.
-        sorting_key        sorting key.
-        sorting_desc       whether to sort in descending order.
- """ - - merger_id: str - type: Literal["merger_distribute"] - items: List[FeedTypes] - distribution_key: str - sorting_key: Optional[str] = None - sorting_desc: bool = False - - @no_type_check - async def _uniform_distribute(self, data: list) -> list: - # Сортируем записи глобально по `created_at` в порядке убывания - if self.sorting_key: - data = sorted(data, key=lambda x: x[self.sorting_key], reverse=self.sorting_desc) - - # Группируем записи по `distribution_key` - grouped_entries = defaultdict(deque) - for entry in data: - grouped_entries[entry[self.distribution_key]].append(entry) - result = [] - prev_profile_id = None - while any(grouped_entries.values()): - for profile_id in list(grouped_entries.keys()): - if grouped_entries[profile_id]: - # Если текущий `distribution_key` отличается от предыдущего или он последний, берем его - if profile_id != prev_profile_id or len(grouped_entries) == 1: - result.append(grouped_entries[profile_id].popleft()) - prev_profile_id = profile_id - if not grouped_entries[profile_id]: # Если записи закончились, удаляем ключ из группы - del grouped_entries[profile_id] - else: - del grouped_entries[profile_id] - - return result - - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных методом append. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: для метода класса. - :return: список данных методом append. - """ - - # Формируем результат append мерджера. - result = FeedResult(data=[], next_page=FeedResultNextPage(data={}), has_next_page=False) - - result_limit = limit - for item in self.items: - # Получаем данные из позиции мерджера. - item_result = await item.get_data( - methods_dict=methods_dict, - user_id=user_id, - limit=result_limit, - next_page=next_page, - redis_client=redis_client, - **params, - ) - - # Добавляем данные позиции к общему результату процентного мерджера. - result.data.extend(item_result.data) - - # Обновляем result_limit - result_limit -= len(item_result.data) - - # Если has_next_page = False, то проверяем has_next_page у позиции и, если необходимо, обновляем. - if not result.has_next_page and item_result.has_next_page: - result.has_next_page = True - - # Обновляем next_page. - result.next_page.data.update(item_result.next_page.data) - - # Если полученных данных хватает, то прерываем итерацию и возвращаем результат. - if result_limit <= 0: - break - - # Распределяем данные равномерно по ключу. - result.data = await self._uniform_distribute(result.data) - return result - - -class SubFeed(BaseFeedConfigModel): - """ - Модель субфида. - - Attributes: - subfeed_id уникальный ID субфида. - type тип объекта - всегда "subfeed". - method_name название клиентского метода для получения данных субфида. - subfeed_params статичные параметры для метода субфида. - shuffle флаг для перемешивания полученных данных мерджера. 
- """ - - subfeed_id: str - type: Literal["subfeed"] - method_name: str - subfeed_params: Dict[str, Any] = {} - raise_error: Optional[bool] = True - shuffle: bool = False - - async def get_data( - self, - methods_dict: Dict[str, Callable], - user_id: Any, - limit: int, - next_page: FeedResultNextPage, - redis_client: Optional[Union[redis.Redis, AsyncRedis]] = None, - **params: Any, - ) -> FeedResult: - """ - Метод для получения данных из метода субфида. - - :param methods_dict: словарь с используемыми методами. - :param user_id: ID объекта для получения данных (например, ID пользователя). - :param limit: кол-во элементов. - :param next_page: курсор пагинации. - :param redis_client: объект клиента Redis (для конфигурации с view_session мерджером). - :param params: параметры для метода. - :return: список данных. - """ - - # Формируем next_page конкретного субфида. - subfeed_next_page = FeedResultNextPageInside( - page=next_page.data[self.subfeed_id].page if self.subfeed_id in next_page.data else 1, - after=next_page.data[self.subfeed_id].after if self.subfeed_id in next_page.data else None, - ) - - # Формируем params для функции субфида. - method_args = inspect.getfullargspec(methods_dict[self.method_name]).args - method_params: Dict[str, Any] = {} - for arg in method_args: - if arg in params: - method_params[arg] = params[arg] - - # Получаем результат функции клиента в формате SubFeedResult. - try: - method_result = await methods_dict[self.method_name]( - user_id=user_id, - limit=limit, - next_page=subfeed_next_page, - **method_params, - **self.subfeed_params, - ) - except (Exception,) as _: - if self.raise_error: - raise - - method_result = FeedResultClient( - data=[], - next_page=subfeed_next_page, - has_next_page=False, - ) - - if not isinstance(method_result, FeedResultClient): - raise TypeError('SubFeed function must return "FeedResultClient" instance.') - - # Если в конфигурации указано "смешать" данные. - if self.shuffle: - shuffle(method_result.data) - - result = FeedResult( - data=method_result.data, - next_page=FeedResultNextPage(data={self.subfeed_id: method_result.next_page}), - has_next_page=method_result.has_next_page, - ) - return result - - class FeedConfig(BaseModel): - """ - Модель конфигурации фида. - - Attributes: - version версия конфигурации. - view_session флаг использования механизма расчета всего фида сразу и сохранения в кэш. - session_size размер кэшируемого фида (limit получения данных для сохранения в кэш). - session_live_time срок хранения в кэше для кэшируемого фида (в секундах). - feed мерджер или субфид. 
- """ + """Top-level feed config model.""" version: str feed: FeedTypes -# Update Forward Refs -MergerPositional.update_forward_refs() -MergerPercentage.update_forward_refs() -SubFeed.update_forward_refs() -MergerPercentageItem.update_forward_refs() -MergerAppend.update_forward_refs() -MergerAppendDistribute.update_forward_refs() -MergerPercentageGradient.update_forward_refs() -MergerViewSession.update_forward_refs() +def _rebuild_model(model: Any) -> None: + """Resolve forward refs across modules (Pydantic v1/v2 compatible).""" + + if hasattr(model, "model_rebuild"): + model.model_rebuild(force=True, _types_namespace={"FeedTypes": FeedTypes}) + else: + model.update_forward_refs(FeedTypes=FeedTypes) + + +for _m in ( + MergerPositional, + MergerPercentage, + MergerPercentageItem, + MergerAppend, + MergerAppendDistribute, + MergerPercentageGradient, + MergerViewSession, + MergerDeduplication, + SubFeed, + FeedConfig, +): + _rebuild_model(_m) + + +__all__ = [ + "BaseFeedConfigModel", + "FeedResult", + "FeedResultClient", + "FeedResultNextPage", + "FeedResultNextPageInside", + "SubFeed", + "MergerAppend", + "MergerAppendDistribute", + "MergerDeduplication", + "MergerPercentage", + "MergerPercentageGradient", + "MergerPercentageItem", + "MergerPositional", + "MergerViewSession", + "FeedTypes", + "FeedConfig", +] diff --git a/tests/fixtures/configs.py b/tests/fixtures/configs.py index 8c96e4e..6aff5cb 100644 --- a/tests/fixtures/configs.py +++ b/tests/fixtures/configs.py @@ -86,3 +86,38 @@ }, }, } + + +PARSING_DEDUP_CONFIG_FIXTURE = { + "version": "1", + "feed": { + "merger_id": "merger_deduplication_parsing_example", + "type": "merger_deduplication", + "dedup_key": "id", + "state_backend": "cursor", + "cursor_compress": True, + "data": { + "merger_id": "merger_percentage_inside_dedup_parsing_example", + "type": "merger_percentage", + "shuffle": False, + "items": [ + { + "percentage": 50, + "data": { + "subfeed_id": "subfeed_dedup_a", + "type": "subfeed", + "method_name": "posted", + }, + }, + { + "percentage": 50, + "data": { + "subfeed_id": "subfeed_dedup_b", + "type": "subfeed", + "method_name": "posted", + }, + }, + ], + }, + }, +} diff --git a/tests/fixtures/dedup_helpers.py b/tests/fixtures/dedup_helpers.py new file mode 100644 index 0000000..eb3fbdc --- /dev/null +++ b/tests/fixtures/dedup_helpers.py @@ -0,0 +1,441 @@ +from smartfeed.schemas import FeedResultClient, FeedResultNextPage + + +def _effective_limit(limit, max_per_call): + effective_limit = limit + if isinstance(max_per_call, int) and max_per_call > 0: + effective_limit = min(effective_limit, max_per_call) + return effective_limit + + +def make_offset_paged_method(items, *, max_per_call=None): + async def _method(user_id, limit, next_page, **kwargs): # pylint: disable=unused-argument + offset = int(next_page.after or 0) + effective_limit = _effective_limit(limit, max_per_call) + result_data = items[offset : offset + effective_limit] + next_page.after = offset + len(result_data) + next_page.page += 1 + has_next_page = (offset + len(result_data)) < len(items) + return FeedResultClient(data=result_data, next_page=next_page, has_next_page=has_next_page) + + return _method + + +def make_string_after_paged_method(items, *, max_per_call=None, after_field="created_at"): + """A subfeed method whose cursor is a string (e.g. timestamp). + + Cursor semantics: `after` is the last returned `created_at` value (monotonic). 
+ """ + + async def _method(user_id, limit, next_page, **kwargs): # pylint: disable=unused-argument + effective_limit = _effective_limit(limit, max_per_call) + + after = next_page.after + start_idx = 0 + if isinstance(after, str) and after: + # Find first item with created_at > after + for i, item in enumerate(items): + if str(item[after_field]) > after: + start_idx = i + break + else: + start_idx = len(items) + + result_data = items[start_idx : start_idx + effective_limit] + has_next_page = (start_idx + len(result_data)) < len(items) + + if result_data: + next_page.after = str(result_data[-1][after_field]) + next_page.page += 1 + return FeedResultClient(data=result_data, next_page=next_page, has_next_page=has_next_page) + + return _method + + +def make_profile_dict_after_method( + profiles_to_items, + *, + max_per_call=None, + after_key="after", +): + """A subfeed method whose cursor is a dict of per-profile offsets. + + Example shape: after = {"p1": 0, "p2": 0} + Cursor semantics: each profile offset increments as items are *read*. + """ + + profile_ids = list(profiles_to_items.keys()) + + async def _method(user_id, limit, next_page, **kwargs): # pylint: disable=unused-argument + effective_limit = _effective_limit(limit, max_per_call) + + after = next_page.after + if not isinstance(after, dict): + after = {pid: 0 for pid in profile_ids} + else: + after = dict(after) + for pid in profile_ids: + after.setdefault(pid, 0) + + result = [] + has_next_page = False + + # Build a cyclic iteration over profiles. + active_profiles = [pid for pid in profile_ids] + + i = 0 + while active_profiles and len(result) < effective_limit: + pid = active_profiles[i % len(active_profiles)] + idx = after.get(pid, 0) + items = profiles_to_items.get(pid, []) + + if idx >= len(items): + # This profile is exhausted. + active_profiles.remove(pid) + continue + + result.append(items[idx]) + after[pid] = idx + 1 + i += 1 + + # Determine if any profile still has unread items. 
+ for pid in profile_ids: + if after.get(pid, 0) < len(profiles_to_items.get(pid, [])): + has_next_page = True + break + + next_page.after = after + next_page.page += 1 + return FeedResultClient(data=result, next_page=next_page, has_next_page=has_next_page) + + return _method + + +def _assert_cursor_monotonic_if_present(res_1, res_2, keys): + for key in keys: + if key not in res_1.next_page.data: + continue + + assert key in res_2.next_page.data + + after_1 = res_1.next_page.data[key].after + after_2 = res_2.next_page.data[key].after + + if after_1 is None or after_2 is None: + continue + + if isinstance(after_1, int) and isinstance(after_2, int): + assert after_2 >= after_1 + continue + + if isinstance(after_1, dict) and isinstance(after_2, dict): + continue + + try: + assert after_2 >= after_1 + except TypeError: + pass + + +def _sources(data): + return [x.get("src") for x in data] + + +def _ids(data): + return [x.get("id") for x in data] + + +def _assert_no_dupes_in_page(data): + ids = _ids(data) + assert len(ids) == len(set(ids)) + + +def _assert_pages_no_overlap(res_1, res_2): + assert not (set(_ids(res_1.data)) & set(_ids(res_2.data))) + + +def _assert_two_pages_no_dupes(res_1, res_2): + _assert_no_dupes_in_page(res_1.data) + _assert_no_dupes_in_page(res_2.data) + _assert_pages_no_overlap(res_1, res_2) + + +def _assert_sources_at_positions(data, positions, expected_src): + sources = _sources(data) + for pos in positions: + assert sources[pos - 1] == expected_src + + +def make_items(src, start, end, *, user_id_mod=None, id_offset=0, extra=None): + items = [] + for i in range(start, end): + item_id = id_offset + i + item = {"id": item_id, "src": src} + if user_id_mod is not None: + item["user_id"] = f"u{item_id % user_id_mod}" + if extra: + item.update(extra) + items.append(item) + return items + + +def _subfeed(subfeed_id, method_name, *, dedup_priority=None): + data = {"subfeed_id": subfeed_id, "type": "subfeed", "method_name": method_name} + if dedup_priority is not None: + data["dedup_priority"] = dedup_priority + return data + + +def _dedup_config(merger_id, data, *, dedup_key="id", state_backend="cursor", cursor_compress=True, **kwargs): + config = { + "merger_id": merger_id, + "type": "merger_deduplication", + "dedup_key": dedup_key, + "state_backend": state_backend, + "cursor_compress": cursor_compress, + "data": data, + } + config.update(kwargs) + return config + + +def _percentage_config(merger_id, items, *, shuffle=False): + return {"merger_id": merger_id, "type": "merger_percentage", "shuffle": shuffle, "items": items} + + +def _append_config(merger_id, items, *, shuffle=False): + return {"merger_id": merger_id, "type": "merger_append", "shuffle": shuffle, "items": items} + + +def _distribute_config(merger_id, items, *, distribution_key="user_id"): + return { + "merger_id": merger_id, + "type": "merger_distribute", + "distribution_key": distribution_key, + "items": items, + } + + +def _positional_config(merger_id, *, positions, positional, default): + return { + "merger_id": merger_id, + "type": "merger_positional", + "positions": positions, + "positional": positional, + "default": default, + } + + +def _gradient_config( + merger_id, + *, + item_from, + item_to, + step, + size_to_step, + shuffle=False, +): + return { + "merger_id": merger_id, + "type": "merger_percentage_gradient", + "item_from": item_from, + "item_to": item_to, + "step": step, + "size_to_step": size_to_step, + "shuffle": shuffle, + } + + +async def _run_two_pages(merger, methods_dict, limit, *, 
next_page=None, **kwargs): + if next_page is None: + next_page = FeedResultNextPage(data={}) + res_1 = await merger.get_data(methods_dict=methods_dict, user_id="u", limit=limit, next_page=next_page, **kwargs) + res_2 = await merger.get_data( + methods_dict=methods_dict, user_id="u", limit=limit, next_page=res_1.next_page, **kwargs + ) + return res_1, res_2 + + +def _percentage_items(first, second, *, first_pct=50, second_pct=50): + return [ + {"percentage": first_pct, "data": first}, + {"percentage": second_pct, "data": second}, + ] + + +def _two_subfeed_spec(*, name="a", subfeed_id="sf_a", max_per_call=None, dedup_priority=None): + return { + "name": name, + "subfeed_id": subfeed_id, + "max_per_call": max_per_call, + "dedup_priority": dedup_priority, + } + + +def _build_two_subfeed_methods(items_a, items_b, *, spec_a=None, spec_b=None): + if spec_a is None: + spec_a = _two_subfeed_spec() + if spec_b is None: + spec_b = _two_subfeed_spec(name="b", subfeed_id="sf_b") + + methods_dict = { + spec_a["name"]: make_offset_paged_method(items_a, max_per_call=spec_a["max_per_call"]), + spec_b["name"]: make_offset_paged_method(items_b, max_per_call=spec_b["max_per_call"]), + } + subfeed_a = _subfeed(spec_a["subfeed_id"], spec_a["name"], dedup_priority=spec_a["dedup_priority"]) + subfeed_b = _subfeed(spec_b["subfeed_id"], spec_b["name"], dedup_priority=spec_b["dedup_priority"]) + return methods_dict, subfeed_a, subfeed_b + + +def _build_two_subfeed_dedup_merger( + *, + items_a, + items_b, + child_builder, + merger_id, + spec_a=None, + spec_b=None, + dedup_kwargs=None, +): + methods_dict, subfeed_a, subfeed_b = _build_two_subfeed_methods( + items_a, + items_b, + spec_a=spec_a, + spec_b=spec_b, + ) + config = _dedup_config(merger_id, child_builder(subfeed_a, subfeed_b), **(dedup_kwargs or {})) + return config, methods_dict, subfeed_a, subfeed_b + + +def _build_deep_positional_pct_dedup_merger( + *, + items_p, + items_d1, + items_d2, + dedup_merger_id, + pos_merger_id, + pct_merger_id, + positions, + overfetch_factor=None, + max_refill_loops=None, +): + methods_dict = { + "p": make_offset_paged_method(items_p), + "d1": make_offset_paged_method(items_d1), + "d2": make_offset_paged_method(items_d2), + } + + dedup_kwargs = {} + if overfetch_factor is not None: + dedup_kwargs["overfetch_factor"] = overfetch_factor + if max_refill_loops is not None: + dedup_kwargs["max_refill_loops"] = max_refill_loops + + config = _dedup_config( + dedup_merger_id, + _positional_config( + pos_merger_id, + positions=positions, + positional=_subfeed("sf_p", "p"), + default=_percentage_config( + pct_merger_id, + items=_percentage_items(_subfeed("sf_d1", "d1"), _subfeed("sf_d2", "d2")), + ), + ), + **dedup_kwargs, + ) + return config, methods_dict + + +def _inner_append_config(*, merger_id: str, subfeed_id: str, method_name: str, dedup_priority: int): + return { + "merger_id": merger_id, + "type": "merger_append", + # Important: dedup deletion priority must be visible at this node so parent mergers + # can fetch higher-priority subtrees first when a dedup wrapper is active. + "dedup_priority": dedup_priority, + "shuffle": False, + "items": [ + { + "subfeed_id": subfeed_id, + "type": "subfeed", + "method_name": method_name, + "dedup_priority": dedup_priority, + } + ], + } + + +def _build_deep_priority_tree_for_merger_type(*, merger_type: str): + """Return a deep tree config where low/high leaves overlap by id. + + The inner leaves are wrapped into an append merger to ensure a "deep" tree even + when the outer merger is flat. 
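+
+    Shape (for merger_positional the high-priority subtree sits on the
+    positional branch; the other types list both subtrees as siblings):
+
+        outer_<type>
+        |-- inner_low  (merger_append)  -> sf_low  (dedup_priority=0)
+        `-- inner_high (merger_append)  -> sf_high (dedup_priority=100)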
+ """ + + low = _inner_append_config(merger_id="inner_low", subfeed_id="sf_low", method_name="low", dedup_priority=0) + high = _inner_append_config(merger_id="inner_high", subfeed_id="sf_high", method_name="high", dedup_priority=100) + + if merger_type == "merger_append": + return { + "merger_id": "outer_append", + "type": "merger_append", + "shuffle": False, + # Put low first intentionally; priority must still make high win for overlapping ids. + "items": [low, high], + } + + if merger_type == "merger_distribute": + return { + "merger_id": "outer_dist", + "type": "merger_distribute", + "distribution_key": "user_id", + # Put low first intentionally. + "items": [low, high], + } + + if merger_type == "merger_percentage": + return { + "merger_id": "outer_pct", + "type": "merger_percentage", + "shuffle": False, + "items": [ + {"percentage": 50, "data": low}, + {"percentage": 50, "data": high}, + ], + } + + if merger_type == "merger_percentage_gradient": + return { + "merger_id": "outer_grad", + "type": "merger_percentage_gradient", + "item_from": {"percentage": 60, "data": low}, + "item_to": {"percentage": 40, "data": high}, + "step": 20, + "size_to_step": 5, + "shuffle": False, + } + + if merger_type == "merger_positional": + # High priority on positional branch so it must win duplicates. + high_pos = _inner_append_config( + merger_id="inner_pos_high", + subfeed_id="sf_high", + method_name="high", + dedup_priority=100, + ) + low_def = _inner_append_config( + merger_id="inner_def_low", + subfeed_id="sf_low", + method_name="low", + dedup_priority=0, + ) + return { + "merger_id": "outer_pos", + "type": "merger_positional", + "positions": [1, 3, 5, 7, 9, 11], + "positional": high_pos, + "default": low_def, + } + + raise AssertionError(f"Unknown merger_type: {merger_type}") diff --git a/tests/fixtures/redis.py b/tests/fixtures/redis.py index b98695e..5c9af72 100644 --- a/tests/fixtures/redis.py +++ b/tests/fixtures/redis.py @@ -1,10 +1,30 @@ import pytest +import pytest_asyncio import redis from redis.asyncio import Redis as AsyncRedis -@pytest.fixture(scope="function") -def redis_client(request): +@pytest_asyncio.fixture(scope="function") +async def redis_client(request): + """Provide a Redis client for tests. + + If Redis is not available on localhost:6379, skip tests that depend on it. 
+ """ + if request.param == "async": - return AsyncRedis(host="localhost", port=6379) - return redis.Redis(host="localhost", port=6379, db=0) + client = AsyncRedis(host="localhost", port=6379) + try: + await client.ping() + except Exception: + pytest.skip("Redis is not available on localhost:6379") + yield client + await client.aclose() + return + + client = redis.Redis(host="localhost", port=6379, db=0) + try: + client.ping() + except Exception: + pytest.skip("Redis is not available on localhost:6379") + yield client + client.close() diff --git a/tests/test_async_loop_blocks_trace.py b/tests/test_async_loop_blocks_trace.py new file mode 100644 index 0000000..a1bafd9 --- /dev/null +++ b/tests/test_async_loop_blocks_trace.py @@ -0,0 +1,467 @@ +import asyncio +import json +import os +import time +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Dict, List, Optional + +import pytest + +from smartfeed.schemas import FeedResultNextPage, MergerDeduplication +from tests.fixtures import dedup_helpers as dh +from tests.fixtures.redis import redis_client # noqa: F401 +from tests.utils import parse_model + + +def _now_us() -> int: + return time.perf_counter_ns() // 1000 + + +@dataclass +class ChromeTraceRecorder: + """Writes Chrome Trace Events JSON for chrome://tracing. + + This is intentionally tiny and test-only: no production dependencies. + """ + + pid: int = 1 + events: List[Dict[str, Any]] = field(default_factory=list) + + def _emit(self, event: Dict[str, Any]) -> None: + self.events.append(event) + + def begin(self, name: str, *, tid: int, ts_us: Optional[int] = None, args: Optional[Dict[str, Any]] = None) -> None: + self._emit( + { + "name": name, + "ph": "B", + "ts": int(_now_us() if ts_us is None else ts_us), + "pid": int(self.pid), + "tid": int(tid), + "args": args or {}, + } + ) + + def end(self, name: str, *, tid: int, ts_us: Optional[int] = None, args: Optional[Dict[str, Any]] = None) -> None: + self._emit( + { + "name": name, + "ph": "E", + "ts": int(_now_us() if ts_us is None else ts_us), + "pid": int(self.pid), + "tid": int(tid), + "args": args or {}, + } + ) + + def instant( + self, name: str, *, tid: int, ts_us: Optional[int] = None, args: Optional[Dict[str, Any]] = None + ) -> None: + self._emit( + { + "name": name, + "ph": "i", + "s": "t", + "ts": int(_now_us() if ts_us is None else ts_us), + "pid": int(self.pid), + "tid": int(tid), + "args": args or {}, + } + ) + + def write(self, path: str) -> None: + payload = {"traceEvents": self.events} + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f) + + +class LoopBlockMonitor: + """Detects event-loop blocking by measuring scheduling lag. + + If the event loop is blocked by long sync work, a periodic sleeper will wake + up late; we track the maximum observed lag. 
+ """ + + def __init__(self, *, sample_interval_s: float = 0.01, block_threshold_s: float = 0.25) -> None: + self.sample_interval_s = float(sample_interval_s) + self.block_threshold_s = float(block_threshold_s) + self.max_lag_s: float = 0.0 + self.block_events: List[float] = [] + self._task: Optional[asyncio.Task[None]] = None + self._stop = asyncio.Event() + + async def __aenter__(self) -> "LoopBlockMonitor": + self._stop.clear() + self._task = asyncio.create_task(self._run()) + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: # type: ignore[override] + self._stop.set() + if self._task is not None: + await self._task + + async def _run(self) -> None: + loop = asyncio.get_running_loop() + expected = loop.time() + self.sample_interval_s + while not self._stop.is_set(): + await asyncio.sleep(self.sample_interval_s) + now = loop.time() + lag = max(0.0, now - expected) + expected = now + self.sample_interval_s + self.max_lag_s = max(self.max_lag_s, lag) + if lag >= self.block_threshold_s: + self.block_events.append(lag) + + +@dataclass +class LeafConcurrencyTracker: + """Tracks how many leaf calls are in-flight concurrently.""" + + current: int = 0 + peak: int = 0 + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + async def enter(self) -> int: + async with self._lock: + self.current += 1 + if self.current > self.peak: + self.peak = self.current + return self.current + + async def exit(self) -> int: + async with self._lock: + self.current = max(0, self.current - 1) + return self.current + + +def _trace_wrap_awaitable( + rec: ChromeTraceRecorder, name: str, awaitable: Awaitable[Any], *, args: Dict[str, Any] +) -> Awaitable[Any]: + async def _wrapped() -> Any: + task = asyncio.current_task() + tid = id(task) if task is not None else 0 + rec.begin(name, tid=tid, args=args) + try: + return await awaitable + finally: + rec.end(name, tid=tid) + + return _wrapped() + + +def _wrap_method_latency(method: Callable[..., Awaitable[Any]], *, latency_s: float) -> Callable[..., Awaitable[Any]]: + async def _wrapped(*args: Any, **kwargs: Any) -> Any: + await asyncio.sleep(latency_s) + return await method(*args, **kwargs) + + return _wrapped + + +def _wrap_leaf_method_traced( + *, + rec: ChromeTraceRecorder, + key: str, + method: Callable[..., Awaitable[Any]], + latency_s: float, + concurrency: LeafConcurrencyTracker, +) -> Callable[..., Awaitable[Any]]: + async def _wrapped(user_id: Any, limit: int, next_page: Any, **kwargs: Any) -> Any: + task = asyncio.current_task() + tid = id(task) if task is not None else 0 + + page = getattr(next_page, "page", None) + after = getattr(next_page, "after", None) + after_type = type(after).__name__ + + if after is None: + after_preview = None + else: + after_preview = str(after) + if len(after_preview) > 120: + after_preview = after_preview[:117] + "..." 
+ + span = f"leaf.{key}" + + in_flight = await concurrency.enter() + rec.begin( + span, + tid=tid, + args={ + "key": key, + "limit": int(limit), + "page": page, + "after_type": after_type, + "after_preview": after_preview, + "in_flight": int(in_flight), + }, + ) + try: + if latency_s > 0: + await asyncio.sleep(float(latency_s)) + return await method(user_id, limit, next_page, **kwargs) + finally: + rec.end(span, tid=tid) + await concurrency.exit() + + return _wrapped + + +@pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True) +@pytest.mark.asyncio +async def test_async_loop_blocks_and_trace_for_deep_tree_all_mergers(redis_client, monkeypatch, tmp_path) -> None: + """A smoke-test for detecting async loop blocks + visualizing concurrency. + + - Builds one deep tree that includes ALL merger types. + - Simulates 2 sequential requests (fresh + next page). + - Forces refills via positional under-fetch (`max_per_call=1`). + - Records loop scheduling lag (blocks/hangs) and optionally exports a Chrome trace. + + Set `SMARTFEED_CHROME_TRACE=/path/to/trace.json` to write a trace. + Open it in Chrome via chrome://tracing. + """ + + # Keep IDs disjoint across sources so "no dupes" is stable. + # Refill waves are forced via max_per_call limits (under-fetch), not via dedup collisions. + items_a = dh.make_items("A", 1, 400, user_id_mod=5, id_offset=1_000) + items_b = dh.make_items("B", 1, 400, user_id_mod=5, id_offset=10_000) + + # Distribute branch: needs distribution_key present (user_id). + items_posted_1 = dh.make_items("posted_1", 1, 80, user_id_mod=3, id_offset=20_000) + items_posted_2 = dh.make_items("posted_2", 1, 120, user_id_mod=3, id_offset=21_000) + + # Gradient branch: overlapping ids again. + items_g1 = dh.make_items("G1", 1, 250, user_id_mod=7, id_offset=30_000) + items_g2 = dh.make_items("G2", 1, 250, user_id_mod=7, id_offset=40_000) + + # View-session leaf. + items_vs = dh.make_items("VS", 1, 160, user_id_mod=11, id_offset=50_000) + + # Positional leaf that intentionally under-fetches to force refill waves. + items_pos_leaf = dh.make_items("POS", 1, 500, user_id_mod=13, id_offset=60_000) + + # --- tracing (test-only monkeypatch) --- + rec = ChromeTraceRecorder() + leaf_concurrency = LeafConcurrencyTracker() + pos_leaf_calls = {"count": 0} + + pos_leaf_base = dh.make_offset_paged_method(items_pos_leaf, max_per_call=1) + + async def _pos_leaf_counted(user_id: Any, limit: int, next_page: Any, **kwargs: Any) -> Any: + pos_leaf_calls["count"] += 1 + return await pos_leaf_base(user_id, limit, next_page, **kwargs) + + # Leaf method tracing: wrap the *actual* subfeed method calls. + # These spans are what you want to inspect for "are leaf calls parallel?". 
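+    # With ~20ms of injected latency per leaf and 8 distinct leaf methods, a
+    # serialized executor would spend roughly 160ms per fetch wave, while a
+    # parallel one stays near 20ms; the gap shows up both in the trace and in
+    # LeafConcurrencyTracker.peak.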
+ leaf_latency_s = 0.02 + methods_dict = { + "a": _wrap_leaf_method_traced( + rec=rec, + key="a", + method=dh.make_offset_paged_method(items_a), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + "b": _wrap_leaf_method_traced( + rec=rec, + key="b", + method=dh.make_offset_paged_method(items_b), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + "posted_1": _wrap_leaf_method_traced( + rec=rec, + key="posted_1", + method=dh.make_offset_paged_method(items_posted_1), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + "posted_2": _wrap_leaf_method_traced( + rec=rec, + key="posted_2", + method=dh.make_offset_paged_method(items_posted_2), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + "g1": _wrap_leaf_method_traced( + rec=rec, + key="g1", + method=dh.make_offset_paged_method(items_g1), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + "g2": _wrap_leaf_method_traced( + rec=rec, + key="g2", + method=dh.make_offset_paged_method(items_g2), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + "vs": _wrap_leaf_method_traced( + rec=rec, + key="vs", + method=dh.make_offset_paged_method(items_vs), + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + # Fetch only 1 item per call even if demand is higher -> triggers refill loops. + "pos_leaf": _wrap_leaf_method_traced( + rec=rec, + key="pos_leaf", + method=_pos_leaf_counted, + latency_s=leaf_latency_s, + concurrency=leaf_concurrency, + ), + } + + view_session_cfg = { + "merger_id": "vs_all", + "type": "merger_view_session", + "session_size": 100, + "session_live_time": 60, + "deduplicate": True, + "dedup_key": "id", + "data": dh._subfeed("sf_vs", "vs"), + } + + pct_cfg = dh._percentage_config( + "pct_all", + items=dh._percentage_items(dh._subfeed("sf_a", "a"), dh._subfeed("sf_b", "b"), first_pct=50, second_pct=50), + ) + + pos_cfg = dh._positional_config( + "pos_all", + # Ensure positional inserts appear across pages for limit~12. + # Use even positions so the schedule starts with the default branch; + # this keeps ordering deterministic. + positions=[2, 4, 6, 8, 10, 12, 14, 16, 18], + positional=dh._subfeed("sf_pos_leaf", "pos_leaf"), + default=pct_cfg, + ) + + dist_cfg = dh._distribute_config( + "dist_all", + items=[dh._subfeed("sf_posted_1", "posted_1"), dh._subfeed("sf_posted_2", "posted_2")], + distribution_key="user_id", + ) + + grad_cfg = dh._gradient_config( + "grad_all", + item_from={"percentage": 70, "data": dh._subfeed("sf_g1", "g1")}, + item_to={"percentage": 30, "data": dh._subfeed("sf_g2", "g2")}, + step=10, + size_to_step=5, + shuffle=False, + ) + + # Include all merger types as siblings so they are executed (and visible in trace), + # while keeping the main output driven by the first branch. + deep_tree = dh._append_config("append_all", [pos_cfg, view_session_cfg, dist_cfg, grad_cfg]) + config = dh._dedup_config( + "dedup_all", + deep_tree, + dedup_key="id", + state_backend="cursor", + overfetch_factor=3, + max_refill_loops=50, + ) + merger = parse_model(MergerDeduplication, config) + + # Patch Executor.gather to wrap each awaitable for Chrome trace. 
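+    # The executor centralizes child concurrency in gather(), so instrumenting
+    # this one method captures every concurrent fan-out in the tree.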
+ from smartfeed.execution.executor import Executor # local import for monkeypatch + + original_gather = Executor.gather + + async def _gather_traced(self: Any, *coros: Any) -> List[Any]: + wrapped = [ + _trace_wrap_awaitable(rec, "executor.gather.op", c, args={"idx": i, "total": len(coros)}) + for i, c in enumerate(coros) + ] + task = asyncio.current_task() + tid = id(task) if task is not None else 0 + rec.begin("executor.gather", tid=tid, args={"n": len(coros)}) + try: + return await original_gather(self, *wrapped) + finally: + rec.end("executor.gather", tid=tid) + + monkeypatch.setattr(Executor, "gather", _gather_traced) + + # Patch Executor.run to show sequential refill loops vs plan execution. + original_run = Executor.run + + async def _run_traced( + self: Any, + node: Any, + ctx: Any, + limit: int, + next_page: Any, + **params: Any, + ) -> Any: + task = asyncio.current_task() + tid = id(task) if task is not None else 0 + node_type = getattr(node, "type", node.__class__.__name__) + node_id = getattr(node, "merger_id", getattr(node, "subfeed_id", None)) + rec.begin( + "executor.run_node", + tid=tid, + args={"node_type": node_type, "node_id": node_id, "limit": int(limit)}, + ) + try: + return await original_run(self, node, ctx, limit, next_page, **params) + finally: + rec.end("executor.run_node", tid=tid) + + monkeypatch.setattr(Executor, "run", _run_traced) + + # --- run: fresh request + next_page --- + limit = 12 + np0 = FeedResultNextPage(data={}) + + async with LoopBlockMonitor(sample_interval_s=0.01, block_threshold_s=0.05) as monitor: + res1 = await asyncio.wait_for( + merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=limit, + next_page=np0, + redis_client=redis_client, + ), + timeout=15, + ) + res2 = await asyncio.wait_for( + merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=limit, + next_page=res1.next_page, + redis_client=redis_client, + ), + timeout=15, + ) + + # Sanity: we should fill the page and maintain dedup invariants. + assert len(res1.data) == limit + assert len({x["id"] for x in res1.data}) == limit + assert len(res2.data) == limit + assert len({x["id"] for x in res2.data}) == limit + + # Hard assertion: leaf calls must overlap (async concurrency), not serialize. + assert leaf_concurrency.peak > 1 + # Refill signal: with max_per_call=1, two page requests should trigger + # multiple extra positional calls to satisfy positional slots. + assert pos_leaf_calls["count"] > 2 + + # Primary signal: event-loop should remain responsive under load. + assert monitor.max_lag_s < 0.1 + + out = os.environ.get("SMARTFEED_CHROME_TRACE") + if out: + # Allow writing to an explicit file path, or to a directory. + out_path = out + if os.path.isdir(out_path): + out_path = os.path.join(out_path, "smartfeed_trace.json") + rec.instant("loop.max_lag", tid=0, args={"max_lag_s": monitor.max_lag_s, "blocks": len(monitor.block_events)}) + rec.write(out_path) + + # Keep references so this test remains useful in local debugging. 
+ _ = tmp_path diff --git a/tests/test_cursor_and_refill_edges.py b/tests/test_cursor_and_refill_edges.py new file mode 100644 index 0000000..27b5b72 --- /dev/null +++ b/tests/test_cursor_and_refill_edges.py @@ -0,0 +1,92 @@ +import base64 +import zlib + +import pytest + +from smartfeed.execution.context import ExecutionContext, RefillExecutionSettings +from smartfeed.execution.executor import Executor +from smartfeed.feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage +from smartfeed.policies.dedup import DeduplicationPolicy +from smartfeed.policies.dedup_utils import decode_seen_from_cursor +from smartfeed.policies.seen_store import CursorSeenStore + + +class _DuplicateOnlyNode(BaseFeedConfigModel): + """A node that always returns the same duplicate item and never ends.""" + + type: str = "test_node" # satisfies pydantic + + def __init__(self, **data): + super().__init__(**data) + object.__setattr__(self, "calls", 0) + + async def get_data( # type: ignore[override] + self, + methods_dict, + user_id, + limit, + next_page, + redis_client=None, + ctx=None, + **params, + ) -> FeedResult: + object.__setattr__(self, "calls", int(getattr(self, "calls", 0)) + 1) + return FeedResult( + data=[{"id": "dup"} for _ in range(int(limit) or 1)], + next_page=FeedResultNextPage(data={}), + has_next_page=True, + ) + + +def _policy_with_seen_dup(*, max_refill_loops: int) -> tuple[DeduplicationPolicy, RefillExecutionSettings]: + store = CursorSeenStore.from_after( + after={"v": 2, "seen": [["dup", 10]]}, + cursor_compress=False, + cursor_max_keys=None, + ) + policy = DeduplicationPolicy( + dedup_key="id", + missing_key_policy="keep", + store=store, + seen_request_set=set(), + ) + settings = RefillExecutionSettings(overfetch_factor=1, max_refill_loops=max_refill_loops) + return policy, settings + + +@pytest.mark.asyncio +async def test_dedup_refill_stops_at_max_loops_when_only_duplicates() -> None: + node = _DuplicateOnlyNode() + executor = Executor() + + dedup_policy, settings = _policy_with_seen_dup(max_refill_loops=2) + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + ctx.dedup = dedup_policy + ctx.refill_settings = settings + + res = await executor.run(node, ctx, limit=3, next_page=FeedResultNextPage(data={})) + + assert res.data == [] + # initial call + 2 refill loops + assert getattr(node, "calls") == 3 + + +def test_decode_seen_from_cursor_raises_on_corrupt_compressed_payload() -> None: + # invalid base64 + with pytest.raises(Exception): + decode_seen_from_cursor({"v": 2, "c": "zlib+base64", "z": "not-base64"}) + + # base64 ok, zlib invalid + bad_zlib = base64.urlsafe_b64encode(b"not-a-zlib-stream").decode("ascii") + with pytest.raises(Exception): + decode_seen_from_cursor({"v": 2, "c": "zlib+base64", "z": bad_zlib}) + + # zlib ok, json invalid + bad_json = base64.urlsafe_b64encode(zlib.compress(b"not json")).decode("ascii") + with pytest.raises(Exception): + decode_seen_from_cursor({"v": 2, "c": "zlib+base64", "z": bad_json}) + + +def test_decode_seen_from_cursor_rejects_wrong_version_or_codec() -> None: + assert decode_seen_from_cursor({"v": 1, "c": "zlib+base64", "z": ""}) == {} + assert decode_seen_from_cursor({"v": 2, "c": "other", "z": ""}) == {} diff --git a/tests/test_dedup_policy_unit.py b/tests/test_dedup_policy_unit.py new file mode 100644 index 0000000..b846f6b --- /dev/null +++ b/tests/test_dedup_policy_unit.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import pytest + +from 
smartfeed.policies.dedup import DeduplicationPolicy +from smartfeed.policies.seen_store import CursorSeenStore + + +def _policy(*, dedup_key: Optional[str], missing_key_policy: str, after: Any = None) -> DeduplicationPolicy: + store = CursorSeenStore.from_after(after=after, cursor_compress=False, cursor_max_keys=None) + return DeduplicationPolicy( + dedup_key=dedup_key, + missing_key_policy=missing_key_policy, # type: ignore[arg-type] + store=store, + seen_request_set=set(), + ) + + +@pytest.mark.asyncio +async def test_accept_batch_dedups_within_request_and_respects_existing_priority() -> None: + # Existing seen key with high priority should block lower/equal priority + existing_after = {"v": 2, "seen": [["1", 10]]} + policy = _policy(dedup_key="id", missing_key_policy="keep", after=existing_after) + + items = [{"id": 1}, {"id": 1}, {"id": 2}] + accepted, inspected = await policy.accept_batch(items=items, priority=5) + + assert inspected == 3 + assert accepted == [{"id": 2}] + + +@pytest.mark.asyncio +async def test_accept_batch_missing_key_policies() -> None: + items = [{"id": 1}, {"nope": 2}] + + policy_drop = _policy(dedup_key="id", missing_key_policy="drop") + accepted_drop, _ = await policy_drop.accept_batch(items=items, priority=0) + assert accepted_drop == [{"id": 1}] + + policy_error = _policy(dedup_key="id", missing_key_policy="error") + with pytest.raises(AssertionError): + await policy_error.accept_batch(items=items, priority=0) + + +@dataclass +class _Owner: + dedup_priority: int + + +@pytest.mark.asyncio +async def test_arbitrate_owner_buffers_prefers_higher_priority_owner() -> None: + policy = _policy(dedup_key="id", missing_key_policy="keep") + + owner_low = _Owner(dedup_priority=1) + owner_high = _Owner(dedup_priority=2) + + owners = [owner_low, owner_high] + owner_rank: Dict[int, int] = {id(owner_low): 0, id(owner_high): 1} + + shared = {"id": "same"} + owner_buffers: Dict[int, List[Any]] = { + id(owner_low): [shared, {"id": "low_only"}], + id(owner_high): [shared, {"id": "high_only"}], + } + + per_owner = await policy.arbitrate_owner_buffers( + owners=owners, + owner_buffers=owner_buffers, + owner_rank=owner_rank, + ) + + assert per_owner[id(owner_high)] == [shared, {"id": "high_only"}] + assert per_owner[id(owner_low)] == [{"id": "low_only"}] diff --git a/tests/test_dedup_utils.py b/tests/test_dedup_utils.py new file mode 100644 index 0000000..06150ad --- /dev/null +++ b/tests/test_dedup_utils.py @@ -0,0 +1,73 @@ +from typing import Any, List + +import pytest + +from smartfeed.feed_models import _redis_call +from smartfeed.policies.dedup_utils import decode_seen_from_cursor, encode_seen_for_cursor, redis_zmscore +from tests.fixtures.redis import redis_client + + +class _RedisNoZmscore: + """Wrapper around a real sync redis client to force the pipeline fallback. + + This keeps the backend "real Redis" while exercising the no-zmscore branch. 
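+
+    Pinning `zmscore = None` on the class is expected to fail redis_zmscore's
+    capability check, steering it onto the pipeline()-based fallback.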
+ """ + + zmscore = None # type: ignore[assignment] + + def __init__(self, client: Any) -> None: + self._client = client + + def pipeline(self) -> Any: + return self._client.pipeline() + + +def test_encode_decode_seen_cursor_compressed_roundtrip_and_truncation() -> None: + seen_updates = [("a", 1), ("b", 2), ("c", 3)] + + encoded = encode_seen_for_cursor(seen_updates, cursor_compress=True, cursor_max_keys=None) + decoded = decode_seen_from_cursor(encoded) + assert decoded == {"a": 1, "b": 2, "c": 3} + + encoded_trunc = encode_seen_for_cursor(seen_updates, cursor_compress=True, cursor_max_keys=2) + decoded_trunc = decode_seen_from_cursor(encoded_trunc) + assert decoded_trunc == {"b": 2, "c": 3} + + +def test_decode_seen_cursor_v2_only() -> None: + assert decode_seen_from_cursor(None) == {} + + # Supported v2 uncompressed format + assert decode_seen_from_cursor({"v": 2, "seen": [["a", 9], ["b", 1]]}) == {"a": 9, "b": 1} + + # Legacy/unknown shapes are intentionally rejected + assert decode_seen_from_cursor(["x", "y"]) == {} + assert decode_seen_from_cursor({"a": 1}) == {} + assert decode_seen_from_cursor({"v": 1, "seen": [["a", 1]]}) == {} + + +@pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True) +@pytest.mark.asyncio +async def test_redis_zmscore_native(redis_client) -> None: + key = "test_zmscore_native" + await _redis_call(redis_client, "delete", key) + await _redis_call(redis_client, "zadd", key, mapping={"a": 1.0, "b": 2.0}) + + res = await redis_zmscore(redis_client, key, ["a", "missing", "b"]) + assert res == [1.0, None, 2.0] + + await _redis_call(redis_client, "delete", key) + + +@pytest.mark.parametrize("redis_client", ["sync"], indirect=True) +@pytest.mark.asyncio +async def test_redis_zmscore_pipeline_fallback_for_sync_client_without_zmscore(redis_client) -> None: + key = "test_zmscore_fallback" + await _redis_call(redis_client, "delete", key) + await _redis_call(redis_client, "zadd", key, mapping={"a": 1.0, "b": 2.0}) + + wrapped = _RedisNoZmscore(redis_client) + res = await redis_zmscore(wrapped, key, ["a", "missing", "b"]) + assert res == [1.0, None, 2.0] + + await _redis_call(redis_client, "delete", key) diff --git a/tests/test_executor_slots_plan_invariants.py b/tests/test_executor_slots_plan_invariants.py new file mode 100644 index 0000000..4b774b4 --- /dev/null +++ b/tests/test_executor_slots_plan_invariants.py @@ -0,0 +1,297 @@ +import pytest + +from smartfeed.execution.context import ExecutionContext, RefillExecutionSettings +from smartfeed.execution.executor import Executor +from smartfeed.execution.plans import SlotSpec, SlotsPlan +from smartfeed.feed_models import BaseFeedConfigModel, FeedResult, FeedResultNextPage, FeedResultNextPageInside +from smartfeed.policies.dedup import DeduplicationPolicy +from smartfeed.policies.seen_store import CursorSeenStore + + +class _Owner(BaseFeedConfigModel): + type: str = "test_owner" + + def __init__(self, *, name: str, **data): + super().__init__(**data) + object.__setattr__(self, "name", name) + object.__setattr__(self, "last_limit", None) + object.__setattr__(self, "calls", 0) + + async def get_data( # type: ignore[override] + self, + methods_dict, + user_id, + limit, + next_page, + redis_client=None, + ctx=None, + **params, + ) -> FeedResult: + object.__setattr__(self, "calls", int(getattr(self, "calls", 0)) + 1) + object.__setattr__(self, "last_limit", int(limit)) + return FeedResult(data=[self.name] * int(limit), next_page=FeedResultNextPage(data={}), has_next_page=False) + + +class 
_PagedOwner(BaseFeedConfigModel): + type: str = "test_paged_owner" + subfeed_id: str + total: int = 10 + + def __init__(self, *, subfeed_id: str, total: int = 10, **data): + super().__init__(subfeed_id=subfeed_id, total=total, **data) + object.__setattr__(self, "calls", 0) + object.__setattr__(self, "limits", []) + + async def get_data( # type: ignore[override] + self, + methods_dict, + user_id, + limit, + next_page, + redis_client=None, + ctx=None, + **params, + ) -> FeedResult: + object.__setattr__(self, "calls", int(getattr(self, "calls", 0)) + 1) + limits = list(getattr(self, "limits", [])) + limits.append(int(limit)) + object.__setattr__(self, "limits", limits) + + entry = next_page.data.get(self.subfeed_id) + offset = int(entry.after) if (entry is not None and isinstance(entry.after, int)) else 0 + + take = max(0, min(int(limit), int(self.total) - offset)) + data = [{"id": f"{self.subfeed_id}_{i}"} for i in range(offset + 1, offset + take + 1)] + new_after = offset + take + + next_page.data[self.subfeed_id] = FeedResultNextPageInside( + page=(entry.page + 1 if entry is not None else 2), + after=new_after, + ) + return FeedResult( + data=data, + next_page=next_page, + has_next_page=bool(new_after < int(self.total)), + ) + + +def _dedup_policy() -> DeduplicationPolicy: + store = CursorSeenStore.from_after(after=None, cursor_compress=False, cursor_max_keys=None) + return DeduplicationPolicy( + dedup_key="id", + missing_key_policy="keep", # type: ignore[arg-type] + store=store, + seen_request_set=set(), + ) + + +@pytest.mark.asyncio +async def test_slots_plan_limit_le_zero_calls_assemble_only() -> None: + executor = Executor() + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + + owner = _Owner(name="x") + + called = {"assemble": 0} + + def assemble(output, next_page, owner_results): + called["assemble"] += 1 + assert output == [] + assert owner_results == {} + return FeedResult(data=output, next_page=next_page, has_next_page=False) + + plan = SlotsPlan( + ctx=ctx, + limit=0, + next_page=FeedResultNextPage(data={}), + params={}, + slots=[SlotSpec(owner=owner, max_count=10)], + assemble=assemble, + ) + + res = await executor.execute_plan(plan) + assert res.data == [] + assert called["assemble"] == 1 + assert getattr(owner, "calls") == 0 + + +@pytest.mark.asyncio +async def test_slots_plan_owner_fetch_limits_overrides_demand() -> None: + executor = Executor() + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + + owner = _Owner(name="x") + + def assemble(output, next_page, owner_results): + return FeedResult(data=output, next_page=next_page, has_next_page=False) + + plan = SlotsPlan( + ctx=ctx, + limit=5, + next_page=FeedResultNextPage(data={}), + params={}, + slots=[SlotSpec(owner=owner, max_count=5)], + assemble=assemble, + owner_fetch_limits={id(owner): 1}, + ) + + res = await executor.execute_plan(plan) + assert getattr(owner, "calls") == 1 + assert getattr(owner, "last_limit") == 1 + assert res.data == ["x"] + + +@pytest.mark.asyncio +async def test_slots_plan_no_ops_path_assemble_still_runs() -> None: + executor = Executor() + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + + owner = _Owner(name="x") + + called = {"assemble": 0} + + def assemble(output, next_page, owner_results): + called["assemble"] += 1 + assert output == [] + assert owner_results == {} + return FeedResult(data=[], next_page=next_page, has_next_page=False) + + plan = SlotsPlan( + ctx=ctx, + limit=5, + next_page=FeedResultNextPage(data={}), + 
params={}, + slots=[SlotSpec(owner=owner, max_count=0)], + assemble=assemble, + ) + + res = await executor.execute_plan(plan) + assert res.data == [] + assert called["assemble"] == 1 + assert getattr(owner, "calls") == 0 + + +@pytest.mark.asyncio +async def test_slots_plan_quota_deficit_triggers_refill_wave() -> None: + executor = Executor() + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + ctx.dedup = _dedup_policy() + + a = _PagedOwner(subfeed_id="a", total=10) + b = _PagedOwner(subfeed_id="b", total=10) + + def assemble(output, next_page, owner_results): + return FeedResult(data=output, next_page=next_page, has_next_page=False) + + plan = SlotsPlan( + ctx=ctx, + limit=6, + next_page=FeedResultNextPage( + data={ + "a": FeedResultNextPageInside(page=1, after=0), + "b": FeedResultNextPageInside(page=1, after=0), + } + ), + params={}, + slots=[ + SlotSpec(owner=a, max_count=3), + SlotSpec(owner=b, max_count=3), + ], + assemble=assemble, + # Force an initial under-fetch for owner a (quota deficit). + owner_fetch_limits={id(a): 1}, + ) + + res = await executor.execute_plan(plan) + + # a should be refilled from 1 -> 3 items + assert getattr(a, "calls") >= 2 + assert res.data[:3] == [{"id": "a_1"}, {"id": "a_2"}, {"id": "a_3"}] + assert res.data[3:] == [{"id": "b_1"}, {"id": "b_2"}, {"id": "b_3"}] + + +@pytest.mark.asyncio +async def test_slots_plan_quota_deficit_stops_refill_when_owner_exhausts() -> None: + executor = Executor() + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + ctx.dedup = _dedup_policy() + + # Owner a can never satisfy its full slot quota. + a = _PagedOwner(subfeed_id="a", total=2) + b = _PagedOwner(subfeed_id="b", total=10) + + def assemble(output, next_page, owner_results): + return FeedResult(data=output, next_page=next_page, has_next_page=False) + + plan = SlotsPlan( + ctx=ctx, + limit=6, + next_page=FeedResultNextPage( + data={ + "a": FeedResultNextPageInside(page=1, after=0), + "b": FeedResultNextPageInside(page=1, after=0), + } + ), + params={}, + slots=[ + SlotSpec(owner=a, max_count=3), + SlotSpec(owner=b, max_count=3), + ], + assemble=assemble, + # Force an initial under-fetch to create a quota deficit for a. + owner_fetch_limits={id(a): 1}, + ) + + res = await executor.execute_plan(plan) + + # a is exhausted after returning 2 total items; refill should stop. + assert getattr(a, "calls") == 2 + assert getattr(a, "limits") == [1, 2] + assert getattr(b, "calls") == 1 + + assert res.data == [ + {"id": "a_1"}, + {"id": "a_2"}, + {"id": "b_1"}, + {"id": "b_2"}, + {"id": "b_3"}, + ] + + +@pytest.mark.asyncio +async def test_slots_plan_quota_deficit_refills_without_dedup_when_refill_settings_present() -> None: + executor = Executor() + ctx = ExecutionContext(methods_dict={}, user_id="u", executor=executor) + ctx.refill_settings = RefillExecutionSettings(overfetch_factor=3, max_refill_loops=10) + + a = _PagedOwner(subfeed_id="a", total=10) + b = _PagedOwner(subfeed_id="b", total=10) + + def assemble(output, next_page, owner_results): + return FeedResult(data=output, next_page=next_page, has_next_page=False) + + plan = SlotsPlan( + ctx=ctx, + limit=6, + next_page=FeedResultNextPage( + data={ + "a": FeedResultNextPageInside(page=1, after=0), + "b": FeedResultNextPageInside(page=1, after=0), + } + ), + params={}, + slots=[ + SlotSpec(owner=a, max_count=3), + SlotSpec(owner=b, max_count=3), + ], + # force an initial under-fetch for owner a. 
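+        # Refill should then top owner a up from 1 item to its full 3-slot quota.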
+        owner_fetch_limits={id(a): 1},
+        assemble=assemble,
+    )
+
+    res = await executor.execute_plan(plan)
+
+    # Refill must still happen even when the dedup policy is absent.
+    assert getattr(a, "calls") >= 2
+    assert res.data[:3] == [{"id": "a_1"}, {"id": "a_2"}, {"id": "a_3"}]
+    assert res.data[3:] == [{"id": "b_1"}, {"id": "b_2"}, {"id": "b_3"}]
diff --git a/tests/test_manager_params.py b/tests/test_manager_params.py
new file mode 100644
index 0000000..a62d160
--- /dev/null
+++ b/tests/test_manager_params.py
@@ -0,0 +1,40 @@
+import pytest
+
+from smartfeed.manager import FeedManager
+from smartfeed.schemas import FeedResultClient, FeedResultNextPage, FeedResultNextPageInside
+
+
+async def meta_method(
+    user_id: str,
+    limit: int,
+    next_page: FeedResultNextPageInside,
+    meta: dict,
+) -> FeedResultClient:
+    assert meta["tag"] == "alpha"
+    take = int(meta.get("take", limit))
+    data = [f"{user_id}:{meta['tag']}"] * min(limit, take)
+    next_page.after = None
+    next_page.page += 1
+    return FeedResultClient(data=data, next_page=next_page, has_next_page=False)
+
+
+@pytest.mark.asyncio
+async def test_manager_passes_params_to_subfeed() -> None:
+    config = {
+        "version": "1",
+        "feed": {
+            "subfeed_id": "sf_meta",
+            "type": "subfeed",
+            "method_name": "meta_method",
+        },
+    }
+
+    manager = FeedManager(config=config, methods_dict={"meta_method": meta_method})
+    result = await manager.get_data(
+        user_id="u1",
+        limit=5,
+        next_page=FeedResultNextPage(data={}),
+        meta={"tag": "alpha", "take": 2},
+    )
+
+    assert result.data == ["u1:alpha", "u1:alpha"]
diff --git a/tests/test_merger_append.py b/tests/test_merger_append.py
index e9db5c7..290dd04 100644
--- a/tests/test_merger_append.py
+++ b/tests/test_merger_append.py
@@ -1,8 +1,11 @@
+import copy
+
 import pytest
 
 from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerAppend
 from tests.fixtures.configs import METHODS_DICT
 from tests.fixtures.mergers import MERGER_APPEND_CONFIG
+from tests.utils import parse_model
 
 
 @pytest.mark.asyncio
@@ -11,7 +14,7 @@ async def test_merger_append() -> None:
     """
     Test fetching data from the append merger.
     """
-    merger_append = MergerAppend.parse_obj(MERGER_APPEND_CONFIG)
+    merger_append = parse_model(MergerAppend, MERGER_APPEND_CONFIG)
     merger_append_res = await merger_append.get_data(
         methods_dict=METHODS_DICT,
         limit=11,
@@ -28,7 +31,7 @@ async def test_merger_append_with_item_1_page_2() -> None:
     """
     Test fetching data from the append merger with a pagination cursor for the first subfeed.
     """
-    merger_append = MergerAppend.parse_obj(MERGER_APPEND_CONFIG)
+    merger_append = parse_model(MergerAppend, MERGER_APPEND_CONFIG)
     merger_append_res = await merger_append.get_data(
         methods_dict=METHODS_DICT,
         limit=11,
@@ -41,3 +44,23 @@ async def test_merger_append_with_item_1_page_2() -> None:
     assert merger_append_res.data == ["x_6", "x_7", "x_8", "x_9", "x_10", "x_1", "x_2", "x_3", "x_4", "x_5", "x_6"]
     assert merger_append_res.next_page.data["subfeed_merger_append_example"].page == 3
     assert merger_append_res.next_page.data["subfeed_merger_append_example"].after == "x_10"
+
+
+@pytest.mark.asyncio
+async def test_merger_append_when_one_leaf_is_empty() -> None:
+    config = copy.deepcopy(MERGER_APPEND_CONFIG)
+    # Make the second leaf return no data + has_next_page=False.
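+    # "empty" is the fixture method in METHODS_DICT that returns no items.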
+ config["items"][1]["method_name"] = "empty" + + merger_append = parse_model(MergerAppend, config) + res = await merger_append.get_data( + methods_dict=METHODS_DICT, + limit=11, + next_page=FeedResultNextPage(data={}), + user_id="x", + ) + + # Only the first subfeed contributes (it is capped to 5 by config). + assert res.data == ["x_1", "x_2", "x_3", "x_4", "x_5"] + # First subfeed's example method still reports more pages. + assert res.has_next_page is True diff --git a/tests/test_merger_append_distribute.py b/tests/test_merger_append_distribute.py index 6bb1782..bc4878b 100644 --- a/tests/test_merger_append_distribute.py +++ b/tests/test_merger_append_distribute.py @@ -3,6 +3,7 @@ from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerAppendDistribute from tests.fixtures.configs import METHODS_DICT from tests.fixtures.mergers import MERGER_APPEND_DISTRIBUTE_CONFIG +from tests.utils import parse_model @pytest.mark.asyncio @@ -11,7 +12,7 @@ async def test_merger_disturbed_append() -> None: Тест для проверки получения данных из append мерджера. """ - merger_distributed = MergerAppendDistribute.parse_obj(MERGER_APPEND_DISTRIBUTE_CONFIG) + merger_distributed = parse_model(MergerAppendDistribute, MERGER_APPEND_DISTRIBUTE_CONFIG) merger_distributed_res = await merger_distributed.get_data( methods_dict=METHODS_DICT, limit=20, @@ -31,7 +32,7 @@ async def test_merger_append_with_item_1_page_2() -> None: """ Тест для проверки получения данных из append мерджера с курсором пагинации первого субфида. """ - merger_distributed = MergerAppendDistribute.parse_obj(MERGER_APPEND_DISTRIBUTE_CONFIG) + merger_distributed = parse_model(MergerAppendDistribute, MERGER_APPEND_DISTRIBUTE_CONFIG) merger_distributed_res = await merger_distributed.get_data( methods_dict=METHODS_DICT, limit=11, diff --git a/tests/test_merger_deduplication.py b/tests/test_merger_deduplication.py new file mode 100644 index 0000000..101eaac --- /dev/null +++ b/tests/test_merger_deduplication.py @@ -0,0 +1,911 @@ +import asyncio + +import pytest + +from smartfeed.feed_models import _redis_call +from smartfeed.schemas import FeedResultClient, FeedResultNextPage, FeedResultNextPageInside, MergerDeduplication +from tests.fixtures import dedup_helpers as dh +from tests.fixtures.redis import redis_client # noqa: F401 +from tests.utils import parse_model + +PROFILES_B_1_TO_8 = { + "p0": [{"id": 1, "src": "B"}, {"id": 3, "src": "B"}, {"id": 5, "src": "B"}, {"id": 7, "src": "B"}], + "p1": [{"id": 2, "src": "B"}, {"id": 4, "src": "B"}, {"id": 6, "src": "B"}, {"id": 8, "src": "B"}], +} + + +def _assert_winning_src_for_ids(data, ids, expected_src: str) -> None: + winning = {item["id"]: item["src"] for item in data} + assert all(winning[i] == expected_src for i in ids if i in winning) + + +def _make_items_by_ids(src: str, ids, *, user_id_mod: int): + return [{"id": i, "user_id": f"u{i % user_id_mod}", "src": src} for i in ids] + + +@pytest.mark.asyncio +async def test_dedup_positional_slot_ownership_cursor_backend() -> None: + """Positional slots must remain owned by the positional branch. + + Deduplication must not drop items *after* the positional merge (which would shift indices). + Instead, duplicates must be skipped inside the leaf source that owns the slot. + """ + + # Default branch has early ids 1..3, which will be seen first. + default_items = dh.make_items("default", 1, 300) + + # Positional branch starts with duplicates 1..3; it must skip them and fetch 4.. instead. 
+ positional_items = dh.make_items("pos", 1, 300) + + methods_dict = { + "default": dh.make_offset_paged_method(default_items), + "pos": dh.make_offset_paged_method(positional_items), + } + + config = dh._dedup_config( + "dedup_wrapper", + dh._positional_config( + "positional_mix", + # Ensure positional inserts exist on both pages for limit=6: + # page1 uses (1,3,5), page2 uses (7,9,11) which map to the same in-page slots. + positions=[1, 3, 5, 7, 9, 11], + positional=dh._subfeed("sf_pos", "pos"), + default=dh._subfeed("sf_default", "default"), + ), + max_refill_loops=20, + ) + + merger = parse_model(MergerDeduplication, config) + + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 6) + + assert len(res_1.data) == 6 + dh._assert_no_dupes_in_page(res_1.data) + + # Slot ownership: configured positions [1,3,5] are the positional branch. + dh._assert_sources_at_positions(res_1.data, [1, 3, 5], "pos") + + # Next page: still no overlap across pages, and positional slots remain owned. + assert len(res_2.data) == 6 + dh._assert_two_pages_no_dupes(res_1, res_2) + dh._assert_sources_at_positions(res_2.data, [1, 3, 5], "pos") + + dh._assert_cursor_monotonic_if_present( + res_1, res_2, keys=["sf_pos", "sf_default", "positional_mix", "dedup_wrapper"] + ) + + +@pytest.mark.asyncio +async def test_dedup_percentage_slot_ownership_cursor_backend() -> None: + """Percentage mixing order must be preserved even with duplicates across sources.""" + + # A is called first by the percentage merger; its ids will be seen before B. + a_items = dh.make_items("A", 1, 300) + + # B starts with duplicates 1..3; it must skip them and fetch unique tail items. + # Same IDs as A to force cross-source duplicates. + b_items = dh.make_items("B", 1, 300) + + config, methods_dict, _, _ = dh._build_two_subfeed_dedup_merger( + items_a=a_items, + items_b=b_items, + merger_id="dedup_wrapper_pct", + child_builder=lambda sf_a, sf_b: dh._percentage_config( + "pct_mix", + items=dh._percentage_items(sf_a, sf_b), + ), + ) + merger = parse_model(MergerDeduplication, config) + + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 10) + + assert len(res_1.data) == 10 + dh._assert_no_dupes_in_page(res_1.data) + + # Slot ownership: percentage merge alternates when list sizes are equal. + assert dh._sources(res_1.data)[:4] == ["A", "B", "A", "B"] + + assert len(res_2.data) == 10 + dh._assert_two_pages_no_dupes(res_1, res_2) + + assert dh._sources(res_2.data)[:2] == ["A", "B"] + + dh._assert_cursor_monotonic_if_present(res_1, res_2, keys=["sf_a", "sf_b", "pct_mix", "dedup_wrapper_pct"]) + + +@pytest.mark.asyncio +async def test_dedup_deep_tree_cursor_backend() -> None: + """Dedup must work through deep merger trees (wrapping leaf methods).""" + + # Leaf sources: intentionally overlapping ids across different leaves. + p_items = dh.make_items("P", 1, 30) + d1_items = dh.make_items("D1", 1, 30) # overlaps P + d2_items = dh.make_items("D2", 1, 30, id_offset=100) + + config, methods_dict = dh._build_deep_positional_pct_dedup_merger( + items_p=p_items, + items_d1=d1_items, + items_d2=d2_items, + dedup_merger_id="dedup_deep", + pos_merger_id="pos_deep", + pct_merger_id="pct_deep", + # Ensure positional positions exist on both page 1 (1,4) and page 2 (9,12) for limit=8. 
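+        # Positions are absolute across the feed: on page 2 (limit=8), 9 maps to in-page slot 1 and 12 to slot 4.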
+ positions=[1, 4, 9, 12], + ) + + merger = parse_model(MergerDeduplication, config) + + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 8) + + assert len(res_1.data) == 8 + dh._assert_no_dupes_in_page(res_1.data) + + # Positional ownership must hold even with deep defaults. + dh._assert_sources_at_positions(res_1.data, [1, 4], "P") + + assert len(res_2.data) == 8 + dh._assert_two_pages_no_dupes(res_1, res_2) + + dh._assert_sources_at_positions(res_2.data, [1, 4], "P") + + +@pytest.mark.asyncio +async def test_dedup_nested_positional_refill_not_masked_by_parent_append() -> None: + """Nested positional refills must run even when parent append can fill the page. + + Regression: + - parent dedup wrapper executes append owners with dedup disabled in owner ctx + - positional child under-fetches (`max_per_call=1`) and needs internal slot refills + - if those refills are skipped, append sibling backfills the page and positional slots are lost + """ + + items_default = dh.make_items("D", 1, 400, id_offset=1_000) + items_pos = dh.make_items("P", 1, 400, id_offset=10_000) + items_fill = dh.make_items("F", 1, 400, id_offset=20_000) + + pos_calls = {"count": 0} + pos_base = dh.make_offset_paged_method(items_pos, max_per_call=1) + + async def _pos_method(user_id, limit, next_page, **kwargs): # pylint: disable=unused-argument + pos_calls["count"] += 1 + return await pos_base(user_id, limit, next_page, **kwargs) + + methods_dict = { + "default": dh.make_offset_paged_method(items_default), + "pos": _pos_method, + "fill": dh.make_offset_paged_method(items_fill), + } + + config = dh._dedup_config( + "dedup_nested_refill", + dh._append_config( + "append_nested_refill", + [ + dh._positional_config( + "pos_nested_refill", + positions=[2, 4, 6, 8, 10, 12], + positional=dh._subfeed("sf_pos_nested", "pos"), + default=dh._subfeed("sf_default_nested", "default"), + ), + dh._subfeed("sf_fill_nested", "fill"), + ], + ), + dedup_key="id", + state_backend="cursor", + overfetch_factor=3, + max_refill_loops=50, + ) + + merger = parse_model(MergerDeduplication, config) + res = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=12, + next_page=FeedResultNextPage(data={}), + ) + + assert len(res.data) == 12 + dh._assert_no_dupes_in_page(res.data) + dh._assert_sources_at_positions(res.data, [2, 4, 6, 8, 10, 12], "P") + assert "F" not in set(dh._sources(res.data)) + assert pos_calls["count"] > 1 + + +@pytest.mark.asyncio +async def test_dedup_nested_percentage_refill_not_masked_by_parent_append() -> None: + """Nested percentage refills must run even when parent append can fill.""" + + items_a = dh.make_items("A", 1, 400, id_offset=1_000) + items_b = dh.make_items("B", 1, 400, id_offset=10_000) + items_fill = dh.make_items("F", 1, 400, id_offset=20_000) + + b_calls = {"count": 0} + b_base = dh.make_offset_paged_method(items_b, max_per_call=1) + + async def _b_method(user_id, limit, next_page, **kwargs): # pylint: disable=unused-argument + b_calls["count"] += 1 + return await b_base(user_id, limit, next_page, **kwargs) + + methods_dict = { + "a": dh.make_offset_paged_method(items_a), + "b": _b_method, + "fill": dh.make_offset_paged_method(items_fill), + } + + percentage_cfg = dh._percentage_config( + "pct_nested_refill", + items=dh._percentage_items( + dh._subfeed("sf_a_nested", "a"), + dh._subfeed("sf_b_nested", "b"), + first_pct=50, + second_pct=50, + ), + ) + + config = dh._dedup_config( + "dedup_nested_pct_refill", + dh._append_config( + "append_nested_pct_refill", + [percentage_cfg, 
dh._subfeed("sf_fill_nested_pct", "fill")], + ), + dedup_key="id", + state_backend="cursor", + overfetch_factor=3, + max_refill_loops=50, + ) + + merger = parse_model(MergerDeduplication, config) + res = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=12, + next_page=FeedResultNextPage(data={}), + ) + + assert len(res.data) == 12 + dh._assert_no_dupes_in_page(res.data) + assert "F" not in set(dh._sources(res.data)) + assert dh._sources(res.data).count("A") == 6 + assert dh._sources(res.data).count("B") == 6 + assert b_calls["count"] > 1 + + +@pytest.mark.parametrize( + "merger_type", + [ + "merger_append", + "merger_distribute", + "merger_positional", + "merger_percentage", + "merger_percentage_gradient", + ], +) +@pytest.mark.asyncio +async def test_dedup_deletion_priority_works_for_deep_trees_all_merger_types(merger_type: str) -> None: + """Deletion priority must work even in deep trees, across merger types. + + For overlapping ids, higher dedup_priority leaf must supply the winning entity. + """ + + # For mixing mergers (percentage/gradient/positional), identical id ranges are enough: the + # high-priority leaf should claim the first chunk of ids and the other leaf must skip them. + # + # For append/distribute, we must ensure BOTH branches contribute to the output (otherwise + # "priority" is unobservable because earlier branches can fill the page). We do that by + # making the low branch short and duplicate-heavy. + if merger_type in {"merger_append", "merger_distribute"}: + low_items = _make_items_by_ids("low", [1, 2, 3, 1000, 1001], user_id_mod=3) + high_items = dh.make_items("high", 1, 200, user_id_mod=3) + else: + low_items = dh.make_items("low", 1, 200, user_id_mod=3) + high_items = dh.make_items("high", 1, 200, user_id_mod=3) + + methods_dict = { + "low": dh.make_offset_paged_method(low_items), + "high": dh.make_offset_paged_method(high_items), + } + + deep_tree = dh._build_deep_priority_tree_for_merger_type(merger_type=merger_type) + config = dh._dedup_config(f"dedup_priority_deep_{merger_type}", deep_tree) + + merger = parse_model(MergerDeduplication, config) + res = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=10, + next_page=FeedResultNextPage(data={}), + ) + + dh._assert_no_dupes_in_page(res.data) + + # For append/distribute, priority is only observable if both branches contribute something. + if merger_type in {"merger_append", "merger_distribute"}: + sources = set(dh._sources(res.data)) + assert "high" in sources + assert "low" in sources + + # Priority is about which source wins overlapping ids (not about output order). + _assert_winning_src_for_ids(res.data, range(1, 6), "high") + + # Placement invariant for positional: positional slots must still be owned by positional branch. + if merger_type == "merger_positional": + sources = dh._sources(res.data) + assert sources[0] == "high" + assert sources[2] == "high" + assert sources[4] == "high" + + +@pytest.mark.asyncio +async def test_dedup_overfetch_factor_does_not_skip_unseen_items_in_deep_tree_cursors() -> None: + """When overfetch_factor>1, leaf cursors must be rewound to inspected count. + + This is a regression test for the "safe overfetch" logic: we may request more + than we need from a leaf source, but we must not advance that leaf cursor past + un-inspected items. In a deep tree, this must hold for all descendant SubFeeds. 
+ """ + + p_items = dh.make_items("P", 1, 200, id_offset=1000) + d1_items = dh.make_items("D1", 1, 200) + d2_items = dh.make_items("D2", 1, 200, id_offset=500) + + config, methods_dict = dh._build_deep_positional_pct_dedup_merger( + items_p=p_items, + items_d1=d1_items, + items_d2=d2_items, + dedup_merger_id="dedup_overfetch", + pos_merger_id="pos_overfetch", + pct_merger_id="pct_overfetch", + positions=[1, 4, 9, 12], + overfetch_factor=3, + ) + + merger = parse_model(MergerDeduplication, config) + + # Page 1/2 + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 8) + + assert len(res_1.data) == 8 + dh._assert_no_dupes_in_page(res_1.data) + + # Dedup merger cursor must exist and advance page. + assert "dedup_overfetch" in res_1.next_page.data + assert res_1.next_page.data["dedup_overfetch"].page == 2 + assert res_1.next_page.data["dedup_overfetch"].after is not None + + # Positional merger cursor must exist and advance page. + assert "pos_overfetch" in res_1.next_page.data + assert res_1.next_page.data["pos_overfetch"].page == 2 + + # Deep descendant cursors: positional leaf requests 2 items; percentage leaves request 4 each. + # With overfetch_factor=3, internal calls may request 6/12, but cursor must not advance that far. + assert res_1.next_page.data["sf_p"].after == 2 + assert res_1.next_page.data["sf_d1"].after == 4 + assert res_1.next_page.data["sf_d2"].after == 4 + + # Page 2 (monotonic advancement, still no over-advancement) + assert len(res_2.data) == 8 + dh._assert_two_pages_no_dupes(res_1, res_2) + + assert res_2.next_page.data["dedup_overfetch"].page == 3 + assert res_2.next_page.data["pos_overfetch"].page == 3 + + assert res_2.next_page.data["sf_p"].after == 4 + assert res_2.next_page.data["sf_d1"].after == 8 + assert res_2.next_page.data["sf_d2"].after == 8 + + dh._assert_cursor_monotonic_if_present( + res_1, + res_2, + keys=["sf_p", "sf_d1", "sf_d2", "pos_overfetch", "dedup_overfetch"], + ) + + +@pytest.mark.asyncio +async def test_dedup_page_zero_resets_seen_and_descendant_cursors() -> None: + items = dh.make_items("S", 1, 50) + methods_dict = {"s": dh.make_offset_paged_method(items)} + + config = dh._dedup_config("dedup_reset", dh._subfeed("sf_stream", "s")) + + merger = parse_model(MergerDeduplication, config) + + res_1 = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=5, + next_page=FeedResultNextPage(data={}), + ) + assert dh._ids(res_1.data) == [1, 2, 3, 4, 5] + + # Simulate client "full reload": page=0 for the dedup merger. + # Also include the stale descendant cursor; dedup should clear it. + res_2 = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=5, + next_page=FeedResultNextPage( + data={ + "dedup_reset": FeedResultNextPageInside(page=0, after=res_1.next_page.data["dedup_reset"].after), + # Use a deliberately bogus descendant cursor; the dedup wrapper must ignore/reset it. + "sf_stream": FeedResultNextPageInside(page=99, after=999), + } + ), + ) + + # Must restart from the beginning. + assert dh._ids(res_2.data) == [1, 2, 3, 4, 5] + # And must not propagate the bogus descendant cursor. + assert res_2.next_page.data["sf_stream"].after == 5 + assert res_2.next_page.data["sf_stream"].page == 2 + + +@pytest.mark.asyncio +async def test_dedup_cursor_backend_persists_seen_state_beyond_two_pages() -> None: + # First 2 pages are unique, then page 3 starts with duplicates from page 1. 
+ items = dh.make_items("S", 1, 11) + dh.make_items("S", 1, 4) + dh.make_items("S", 11, 31) + methods_dict = {"s": dh.make_offset_paged_method(items)} + + config = dh._dedup_config( + "dedup_cursor_3p", + dh._subfeed("sf_stream", "s"), + state_backend="cursor", + ) + merger = parse_model(MergerDeduplication, config) + + res_1 = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=5, + next_page=FeedResultNextPage(data={}), + ) + res_2 = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=5, + next_page=res_1.next_page, + ) + res_3 = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=5, + next_page=res_2.next_page, + ) + + ids_1 = dh._ids(res_1.data) + ids_2 = dh._ids(res_2.data) + ids_3 = dh._ids(res_3.data) + + assert ids_1 == [1, 2, 3, 4, 5] + assert ids_2 == [6, 7, 8, 9, 10] + assert ids_3 == [11, 12, 13, 14, 15] + assert not (set(ids_1) & set(ids_2)) + assert not (set(ids_1) & set(ids_3)) + assert not (set(ids_2) & set(ids_3)) + + +@pytest.mark.asyncio +async def test_dedup_append_cursor_backend_across_pages_and_refill_advances_leaf_cursor_exactly() -> None: + """Append: across pages there is no overlap; refill advances cursors correctly. + + This uses a max_per_call=1 method for the duplicate-heavy leaf so the wrapper + must do multiple client calls (refill loops). + """ + + a_items = [{"id": 1, "src": "A"}, {"id": 2, "src": "A"}] + b_items = dh.make_items("B", 1, 50) + + config, methods_dict, _, _ = dh._build_two_subfeed_dedup_merger( + items_a=a_items, + items_b=b_items, + merger_id="dedup_append_pages", + child_builder=lambda sf_a, sf_b: dh._append_config("append_mix", [sf_a, sf_b]), + spec_b=dh._two_subfeed_spec(name="b", subfeed_id="sf_b", max_per_call=1), + dedup_kwargs={"max_refill_loops": 50}, + ) + merger = parse_model(MergerDeduplication, config) + + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 5) + + assert dh._ids(res_1.data) == [1, 2, 3, 4, 5] + assert dh._sources(res_1.data) == ["A", "A", "B", "B", "B"] + + assert res_1.next_page.data["dedup_append_pages"].page == 2 + + # In default arbitrate mode, B only needs to scan far enough to fill the remaining + # portion of the page after arbitration (here: 3 items: ids 3..5). + assert res_1.next_page.data["sf_b"].after == 5 + b_contributed = sum(1 for x in res_1.data if x.get("src") == "B") + assert res_1.next_page.data["sf_b"].after > b_contributed + + # A is exhausted after 2 reads; ensure cursor reflects that. + assert res_1.next_page.data["sf_a"].after == 2 + + assert len(res_2.data) == 5 + dh._assert_two_pages_no_dupes(res_1, res_2) + # Across two pages, B should have advanced exactly 5 more items. 
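+    # Page 1 left sf_b at after == 5, so page 2 must end at after == 10.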
+ assert res_2.next_page.data["sf_b"].after == 10 + + +@pytest.mark.asyncio +async def test_dedup_arbitrate_mode_runs_parallel_prefetch_and_arbitrates_winners() -> None: + started_a = asyncio.Event() + started_b = asyncio.Event() + release = asyncio.Event() + + items_a = dh.make_items("A", 1, 200) + items_b = dh.make_items("B", 1, 200) + + def make_method(*, items, started_event): + async def _method(user_id, limit, next_page, **kwargs): # pylint: disable=unused-argument + started_event.set() + await release.wait() + offset = int(next_page.after or 0) + data = items[offset : offset + limit] + next_page.after = offset + len(data) + next_page.page += 1 + return FeedResultClient(data=data, next_page=next_page, has_next_page=True) + + return _method + + methods_dict = { + "a": make_method(items=items_a, started_event=started_a), + "b": make_method(items=items_b, started_event=started_b), + } + + config = dh._dedup_config( + "dedup_arbitrate", + dh._percentage_config( + "pct", + items=dh._percentage_items(dh._subfeed("sf_a", "a"), dh._subfeed("sf_b", "b")), + ), + ) + + merger = parse_model(MergerDeduplication, config) + + task = asyncio.create_task( + merger.get_data(methods_dict=methods_dict, user_id="u", limit=10, next_page=FeedResultNextPage(data={})) + ) + + # If execution is sequential, one of these would time out. + await asyncio.wait_for(started_a.wait(), timeout=1) + await asyncio.wait_for(started_b.wait(), timeout=1) + release.set() + + res = await asyncio.wait_for(task, timeout=2) + + assert len(res.data) == 10 + dh._assert_no_dupes_in_page(res.data) + # With equal priorities, stable tie-breaker should pick A (first branch) for overlapping keys. + _assert_winning_src_for_ids(res.data, range(1, 6), "A") + + +@pytest.mark.asyncio +async def test_dedup_refill_loops_advance_dict_after_cursor_not_just_page() -> None: + """Dedup refill loops must correctly advance dict-shaped `after` cursors.""" + + # A produces ids 1,2. + a_items = [{"id": 1, "src": "A"}, {"id": 2, "src": "A"}] + + # B produces ids 1.. in round-robin across profiles; cursor is per-profile offsets. + b_profiles = PROFILES_B_1_TO_8 + + methods_dict = { + "a": dh.make_offset_paged_method(a_items), + "b": dh.make_profile_dict_after_method(b_profiles), + } + + # Use a percentage merger so B is asked for a small limit (2 items for limit=4). + # This forces refill loops when B's first batch is all duplicates. + config = dh._dedup_config( + "dedup_dict_after", + dh._percentage_config( + "pct_mix", + items=dh._percentage_items( + dh._subfeed("sf_a", "a", dedup_priority=100), + dh._subfeed("sf_b", "b", dedup_priority=0), + ), + ), + max_refill_loops=50, + ) + + merger = parse_model(MergerDeduplication, config) + res = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=4, + next_page=FeedResultNextPage(data={}), + ) + + assert len(res.data) == 4 + dh._assert_no_dupes_in_page(res.data) + assert set(dh._ids(res.data)) == {1, 2, 3, 4} + assert "sf_b" in res.next_page.data + assert isinstance(res.next_page.data["sf_b"].after, dict) + + # B contributed 2 items (3,4) but must have *read* 4 items (1..4) to skip duplicates. + b_after = res.next_page.data["sf_b"].after + read_count = sum(int(v) for v in b_after.values()) + assert read_count == 4 + + +@pytest.mark.asyncio +async def test_dedup_overfetch_does_not_overadvance_non_int_after_cursor() -> None: + """overfetch_factor must not cause over-advancement for non-rewindable cursors.""" + + # Single subfeed with dict after cursor; no dedup skips should happen. 
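+    # Dict-shaped cursors are not rewindable, so the safe-overfetch logic must not apply here.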
+ profiles = PROFILES_B_1_TO_8 + + methods_dict = { + "b": dh.make_profile_dict_after_method(profiles), + } + + config = dh._dedup_config( + "dedup_nonint_overfetch", + dh._subfeed("sf_b", "b"), + overfetch_factor=5, + ) + + merger = parse_model(MergerDeduplication, config) + res = await merger.get_data(methods_dict=methods_dict, user_id="u", limit=4, next_page=FeedResultNextPage(data={})) + + assert len(res.data) == 4 + after = res.next_page.data["sf_b"].after + assert isinstance(after, dict) + # If overfetch were incorrectly applied, we'd see more than 4 reads. + assert sum(int(v) for v in after.values()) == 4 + + +@pytest.mark.asyncio +async def test_dedup_overfetch_rewinds_offset_cursor_when_first_batch_all_duplicates() -> None: + """Overfetch should be safe: when we oversample, we must rewind offset cursors. + + Scenario: + - A (high priority) returns ids 1..5 + - B (low priority) initially returns only duplicates (1..5) + - On the next refill loop, B overfetches but must rewind `after` to inspected count + so it doesn't skip items. + """ + + items_a = dh.make_items("A", 1, 300) + items_b = dh.make_items("B", 1, 300) + + config, methods_dict, _, _ = dh._build_two_subfeed_dedup_merger( + items_a=items_a, + items_b=items_b, + merger_id="dedup_overfetch_rewind", + child_builder=lambda sf_a, sf_b: dh._percentage_config( + "pct_mix", + items=dh._percentage_items(sf_a, sf_b), + ), + spec_a=dh._two_subfeed_spec(dedup_priority=100), + spec_b=dh._two_subfeed_spec(name="b", subfeed_id="sf_b", dedup_priority=0), + dedup_kwargs={"overfetch_factor": 3, "max_refill_loops": 20}, + ) + merger = parse_model(MergerDeduplication, config) + res = await merger.get_data( + methods_dict=methods_dict, + user_id="u", + limit=10, + next_page=FeedResultNextPage(data={}), + ) + + assert len(res.data) == 10 + dh._assert_no_dupes_in_page(res.data) + + # A provides 1..5, B must provide 6..10. + _assert_winning_src_for_ids(res.data, range(1, 6), "A") + _assert_winning_src_for_ids(res.data, range(6, 11), "B") + + # Cursor rewind check: + # - First loop for B reads 5 duplicates -> after becomes 5 + # - Second loop overfetches, but must rewind to inspected 5 more -> after should end at 10 + assert res.next_page.data["sf_b"].after == 10 + + +@pytest.mark.parametrize( + "items_a,items_b,min_b_id", + [ + (dh.make_items("A", 1, 4, user_id_mod=2), dh.make_items("B", 1, 200, user_id_mod=2), 4), + (dh.make_items("A", 1, 200, user_id_mod=3), dh.make_items("B", 1, 200, user_id_mod=3), None), + ], +) +@pytest.mark.asyncio +async def test_dedup_distribute_cursor_backend_across_pages_preserves_source_refill(items_a, items_b, min_b_id) -> None: + """Distribute: duplicates skipped per-leaf and page slices don't overlap.""" + + config, methods_dict, _, _ = dh._build_two_subfeed_dedup_merger( + items_a=items_a, + items_b=items_b, + merger_id="dedup_dist_pages", + child_builder=lambda sf_a, sf_b: dh._distribute_config("dist", [sf_a, sf_b]), + ) + merger = parse_model(MergerDeduplication, config) + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 10) + + assert len(res_1.data) == 10 + assert len(res_2.data) == 10 + dh._assert_two_pages_no_dupes(res_1, res_2) + + # Placement/refill: B must skip duplicate ids 1..3 and still fill the page. 
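+    # min_b_id is None for the fully overlapping parametrization, where B's smallest surviving id is not pinned.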
+ if min_b_id is not None: + b_ids_1 = [x["id"] for x in res_1.data if x.get("src") == "B"] + assert b_ids_1 and min(b_ids_1) >= min_b_id + + +@pytest.mark.asyncio +async def test_dedup_percentage_gradient_cursor_backend_across_pages() -> None: + a_items = dh.make_items("A", 1, 300) + b_items = dh.make_items("B", 1, 30) + dh.make_items("B", 1, 300, id_offset=1000) + + methods_dict = { + "a": dh.make_offset_paged_method(a_items), + "b": dh.make_offset_paged_method(b_items), + } + + config = dh._dedup_config( + "dedup_grad_pages", + dh._gradient_config( + "grad_mix", + item_from={"percentage": 60, "data": dh._subfeed("sf_a", "a")}, + item_to={"percentage": 40, "data": dh._subfeed("sf_b", "b")}, + step=20, + size_to_step=5, + shuffle=False, + ), + max_refill_loops=50, + ) + + merger = parse_model(MergerDeduplication, config) + res_1, res_2 = await dh._run_two_pages(merger, methods_dict, 10) + + dh._assert_two_pages_no_dupes(res_1, res_2) + + sources = dh._sources(res_1.data) + assert sources == ["A", "A", "A", "B", "B", "A", "A", "B", "B", "B"] + + # Gradient merger cursor should exist and advance. + assert res_1.next_page.data["grad_mix"].page == 2 + assert res_2.next_page.data["grad_mix"].page == 3 + + +@pytest.mark.parametrize( + "merger_id,custom_deduplication_key,items_a,items_b,child_builder", + [ + ( + "dedup_redis", + "t1", + dh.make_items("A", 1, 300), + dh.make_items("B", 1, 300), # Same IDs as A to force cross-source duplicates. + lambda sf_a, sf_b: dh._percentage_config("pct_mix", items=dh._percentage_items(sf_a, sf_b)), + ), + ( + "dedup_redis_append", + "t2", + dh.make_items("A", 1, 20), + dh.make_items("B", 1, 300), + lambda sf_a, sf_b: dh._append_config("append_mix", [sf_a, sf_b]), + ), + ], +) +@pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True) +@pytest.mark.asyncio +async def test_dedup_redis_backend_cross_page( + redis_client, + merger_id, + custom_deduplication_key, + items_a, + items_b, + child_builder, +) -> None: + config, methods_dict, _, _ = dh._build_two_subfeed_dedup_merger( + items_a=items_a, + items_b=items_b, + merger_id=merger_id, + child_builder=child_builder, + dedup_kwargs={"state_backend": "redis", "state_ttl_seconds": 60}, + ) + merger = parse_model(MergerDeduplication, config) + + res_1, res_2 = await dh._run_two_pages( + merger, + methods_dict, + 10, + redis_client=redis_client, + custom_deduplication_key=custom_deduplication_key, + ) + + dh._assert_two_pages_no_dupes(res_1, res_2) + + # Redis backend should not store seen ids in cursor after. + assert merger_id in res_2.next_page.data + assert res_2.next_page.data[merger_id].after is None + + # Ensure state is persisted in Redis. + key = f"dedup:{merger_id}:u:{custom_deduplication_key}" + members = await _redis_call(redis_client, "zrange", key, 0, -1) + assert len(members) >= len(set(dh._ids(res_1.data) + dh._ids(res_2.data))) + + +@pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True) +@pytest.mark.asyncio +async def test_dedup_wrapper_with_view_session_merger(redis_client) -> None: + """Dedup wrapper must work when the child is a view_session merger.""" + + # Two leaves with overlapping ids; view_session computes a session once. 
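+    # session_size=30 covers both 10-item pages served from the cached session below.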
+    items_low = dh.make_items("low", 1, 100)
+    items_high = dh.make_items("high", 1, 100)
+
+    methods_dict, subfeed_low, subfeed_high = dh._build_two_subfeed_methods(
+        items_low,
+        items_high,
+        spec_a=dh._two_subfeed_spec(name="low", subfeed_id="sf_low", dedup_priority=0),
+        spec_b=dh._two_subfeed_spec(name="high", subfeed_id="sf_high", dedup_priority=100),
+    )
+
+    config = dh._dedup_config(
+        "dedup_vs",
+        {
+            "merger_id": "vs",
+            "type": "merger_view_session",
+            "session_size": 30,
+            "session_live_time": 60,
+            "deduplicate": False,
+            "shuffle": False,
+            "data": dh._percentage_config(
+                "pct",
+                items=dh._percentage_items(subfeed_low, subfeed_high),
+            ),
+        },
+    )
+
+    merger = parse_model(MergerDeduplication, config)
+
+    res_1, res_2 = await dh._run_two_pages(
+        merger,
+        methods_dict,
+        10,
+        redis_client=redis_client,
+        custom_view_session_key="vs1",
+    )
+
+    dh._assert_two_pages_no_dupes(res_1, res_2)
+
+    # Deletion priority: for the overlapping early ids, the winning entity must be from high.
+    _assert_winning_src_for_ids(res_1.data + res_2.data, range(1, 11), "high")
+
+
+@pytest.mark.asyncio
+async def test_dedup_in_page_deletion_priority_keeps_high_priority_even_if_config_order_is_low_first() -> None:
+    """High dedup_priority source must not be deleted even if called later in config order.
+
+    We use a percentage merger where both branches have overlapping ids.
+    The "high" branch is second in config, but has higher dedup_priority.
+    """
+
+    low_items = dh.make_items("low", 1, 200)
+    high_items = dh.make_items("high", 1, 200)
+
+    config, methods_dict, _, _ = dh._build_two_subfeed_dedup_merger(
+        items_a=low_items,
+        items_b=high_items,
+        merger_id="dedup_priority",
+        child_builder=lambda sf_a, sf_b: dh._percentage_config(
+            "pct",
+            items=dh._percentage_items(sf_a, sf_b),
+        ),
+        spec_a=dh._two_subfeed_spec(name="low", subfeed_id="sf_low", dedup_priority=0),
+        spec_b=dh._two_subfeed_spec(name="high", subfeed_id="sf_high", dedup_priority=100),
+    )
+    merger = parse_model(MergerDeduplication, config)
+    res = await merger.get_data(
+        methods_dict=methods_dict,
+        user_id="u",
+        limit=10,
+        next_page=FeedResultNextPage(data={}),
+    )
+
+    dh._assert_no_dupes_in_page(res.data)
+    # Priority is about which source "wins" for a given dedup_key, not about output order.
+    # With 50/50 limits, the high-priority branch should supply ids 1..5, while the low-priority
+    # branch will be advanced to avoid duplicates.
+    _assert_winning_src_for_ids(res.data, range(1, 6), "high")
diff --git a/tests/test_merger_percentage.py b/tests/test_merger_percentage.py
index e5ab76e..328dc39 100644
--- a/tests/test_merger_percentage.py
+++ b/tests/test_merger_percentage.py
@@ -1,8 +1,11 @@
+import copy
+
 import pytest
 
 from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerPercentage
 from tests.fixtures.configs import METHODS_DICT
 from tests.fixtures.mergers import MERGER_PERCENTAGE_CONFIG
+from tests.utils import parse_model
 
 
 @pytest.mark.asyncio
@@ -11,7 +14,7 @@ async def test_merger_percentage() -> None:
     """
     Test fetching data from the percentage merger.
""" - merger_percentage = MergerPercentage.parse_obj(MERGER_PERCENTAGE_CONFIG) + merger_percentage = parse_model(MergerPercentage, MERGER_PERCENTAGE_CONFIG) merger_percentage_res = await merger_percentage.get_data( methods_dict=METHODS_DICT, limit=10, @@ -25,3 +28,45 @@ async def test_merger_percentage() -> None: ) assert merger_percentage_res.data == ["x_4", "x_21", "x_22", "x_5", "x_23", "x_24", "x_6", "x_25", "x_26", "x_7"] + + +@pytest.mark.asyncio +async def test_merger_percentage_when_one_leaf_is_empty() -> None: + config = copy.deepcopy(MERGER_PERCENTAGE_CONFIG) + # Make the second leaf return no data + has_next_page=False. + config["items"][1]["data"]["method_name"] = "empty" + + merger_percentage = parse_model(MergerPercentage, config) + res = await merger_percentage.get_data( + methods_dict=METHODS_DICT, + limit=10, + next_page=FeedResultNextPage( + data={ + "subfeed_merger_percentage_example": FeedResultNextPageInside(page=2, after="x_3"), + "subfeed_2_merger_percentage_example": FeedResultNextPageInside(page=3, after="x_20"), + } + ), + user_id="x", + ) + + assert res.data == ["x_4", "x_5", "x_6", "x_7"] + # The non-empty leaf still reports more pages. + assert res.has_next_page is True + + +@pytest.mark.asyncio +async def test_merger_percentage_odd_limit_fills_page_when_sources_have_data() -> None: + merger_percentage = parse_model(MergerPercentage, MERGER_PERCENTAGE_CONFIG) + res = await merger_percentage.get_data( + methods_dict=METHODS_DICT, + limit=11, + next_page=FeedResultNextPage( + data={ + "subfeed_merger_percentage_example": FeedResultNextPageInside(page=2, after="x_3"), + "subfeed_2_merger_percentage_example": FeedResultNextPageInside(page=3, after="x_20"), + } + ), + user_id="x", + ) + + assert len(res.data) == 11 diff --git a/tests/test_merger_percentage_gradient.py b/tests/test_merger_percentage_gradient.py index eaaba9c..e2c6607 100644 --- a/tests/test_merger_percentage_gradient.py +++ b/tests/test_merger_percentage_gradient.py @@ -3,6 +3,7 @@ from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerPercentageGradient from tests.fixtures.configs import METHODS_DICT from tests.fixtures.mergers import MERGER_PERCENTAGE_GRADIENT_CONFIG +from tests.utils import parse_model @pytest.mark.asyncio @@ -11,7 +12,7 @@ async def test_merger_percentage_gradient() -> None: Тест для проверки получения данных из процентного мерджера с градиентом. """ - merger_percentage_gradient = MergerPercentageGradient.parse_obj(MERGER_PERCENTAGE_GRADIENT_CONFIG) + merger_percentage_gradient = parse_model(MergerPercentageGradient, MERGER_PERCENTAGE_GRADIENT_CONFIG) merger_percentage_gradient_res = await merger_percentage_gradient.get_data( methods_dict=METHODS_DICT, limit=10, @@ -44,7 +45,7 @@ async def test_merger_percentage_gradient_next_page() -> None: Тест для проверки получения данных из процентного мерджера с градиентом после изменения процента на другой странице. 
""" - merger_percentage_gradient = MergerPercentageGradient.parse_obj(MERGER_PERCENTAGE_GRADIENT_CONFIG) + merger_percentage_gradient = parse_model(MergerPercentageGradient, MERGER_PERCENTAGE_GRADIENT_CONFIG) merger_percentage_gradient_res = await merger_percentage_gradient.get_data( methods_dict=METHODS_DICT, limit=10, @@ -70,3 +71,21 @@ async def test_merger_percentage_gradient_next_page() -> None: "x_22", "x_23", ] + + +@pytest.mark.asyncio +async def test_merger_percentage_gradient_odd_limit_fills_page_when_sources_have_data() -> None: + merger_percentage_gradient = parse_model(MergerPercentageGradient, MERGER_PERCENTAGE_GRADIENT_CONFIG) + res = await merger_percentage_gradient.get_data( + methods_dict=METHODS_DICT, + limit=11, + next_page=FeedResultNextPage( + data={ + "subfeed_from_merger_percentage_gradient_example": FeedResultNextPageInside(page=2, after="x_3"), + "subfeed_to_merger_percentage_gradient_example": FeedResultNextPageInside(page=3, after="x_20"), + } + ), + user_id="x", + ) + + assert len(res.data) == 11 diff --git a/tests/test_merger_positional.py b/tests/test_merger_positional.py index c0f3815..370f770 100644 --- a/tests/test_merger_positional.py +++ b/tests/test_merger_positional.py @@ -3,6 +3,7 @@ from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerPositional from tests.fixtures.configs import METHODS_DICT from tests.fixtures.mergers import MERGER_POSITIONAL_CONFIG +from tests.utils import parse_model @pytest.mark.asyncio @@ -11,7 +12,7 @@ async def test_merger_positional_with_positions() -> None: Тест для проверки получения данных из позиционного мерджера на основе позиций в конфигурации. """ - merger_positional = MergerPositional.parse_obj(MERGER_POSITIONAL_CONFIG) + merger_positional = parse_model(MergerPositional, MERGER_POSITIONAL_CONFIG) merger_positional_res = await merger_positional.get_data( methods_dict=METHODS_DICT, limit=9, @@ -33,7 +34,7 @@ async def test_merger_positional_with_step() -> None: Тест для проверки получения данных из позиционного мерджера на основе шагов в конфигурации. """ - merger_positional = MergerPositional.parse_obj(MERGER_POSITIONAL_CONFIG) + merger_positional = parse_model(MergerPositional, MERGER_POSITIONAL_CONFIG) merger_positional_res = await merger_positional.get_data( methods_dict=METHODS_DICT, limit=10, @@ -56,7 +57,7 @@ async def test_merger_positional_with_empty_default() -> None: Тест для проверки получения данных из позиционного мерджера на основе шагов в конфигурации. 
""" - merger_positional = MergerPositional.parse_obj(MERGER_POSITIONAL_CONFIG) + merger_positional = parse_model(MergerPositional, MERGER_POSITIONAL_CONFIG) merger_positional.default.method_name = "empty" merger_positional_res = await merger_positional.get_data( methods_dict=METHODS_DICT, diff --git a/tests/test_merger_view_session.py b/tests/test_merger_view_session.py index 78a0566..f62a096 100644 --- a/tests/test_merger_view_session.py +++ b/tests/test_merger_view_session.py @@ -1,12 +1,17 @@ -import inspect import json import pytest +from smartfeed.feed_models import _redis_call from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerViewSession from tests.fixtures.configs import METHODS_DICT from tests.fixtures.mergers import MERGER_VIEW_SESSION_CONFIG, MERGER_VIEW_SESSION_DUPS_CONFIG from tests.fixtures.redis import redis_client +from tests.utils import parse_model + + +async def _get_cache_json(redis_client, key: str): + return json.loads(await _redis_call(redis_client, "get", key)) @pytest.mark.asyncio @@ -15,7 +20,7 @@ async def test_merger_view_session_no_redis() -> None: Тест для проверки получения данных из мерджера с кэшированием без клиента Redis. """ - merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_CONFIG) + merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_CONFIG) with pytest.raises(ValueError): await merger_vs.get_data( methods_dict=METHODS_DICT, @@ -32,7 +37,7 @@ async def test_merger_view_session(redis_client) -> None: Тест для проверки получения данных из мерджера с кэшированием. """ - merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_CONFIG) + merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_CONFIG) merger_vs_res = await merger_vs.get_data( methods_dict=METHODS_DICT, limit=10, @@ -40,12 +45,7 @@ async def test_merger_view_session(redis_client) -> None: user_id="x", redis_client=redis_client, ) - merger_vs_cache = redis_client.get(name="merger_view_session_example_x") - # Для использования синхронной и асинхронной фикстуры в одном тесте проверяем метод get - if inspect.iscoroutine(merger_vs_cache): - merger_vs_cache = json.loads(await merger_vs_cache) - else: - merger_vs_cache = json.loads(merger_vs_cache) + merger_vs_cache = await _get_cache_json(redis_client, "merger_view_session_example_x") assert merger_vs_res.data == ["x_1", "x_2", "x_3", "x_4", "x_5", "x_6", "x_7", "x_8", "x_9", "x_10"] assert len(merger_vs_cache) == merger_vs.session_size @@ -59,7 +59,7 @@ async def test_merger_view_session_custom_key(redis_client) -> None: Тест для проверки получения данных из мерджера с кэшированием по ключу с кастомным постфиксом. """ - merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_CONFIG) + merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_CONFIG) # Даем дополнительный параметр, который мерджер добавит в ключ кэша. 
     merger_vs_res = await merger_vs.get_data(
         methods_dict=METHODS_DICT,
@@ -69,12 +69,7 @@ async def test_merger_view_session_custom_key(redis_client) -> None:
         redis_client=redis_client,
         custom_view_session_key="foo",
     )
-    merger_vs_cache = redis_client.get(name="merger_view_session_example_x_foo")
-    # To support the sync and async fixtures in one test, check what get() returned.
-    if inspect.iscoroutine(merger_vs_cache):
-        merger_vs_cache = json.loads(await merger_vs_cache)
-    else:
-        merger_vs_cache = json.loads(merger_vs_cache)
+    merger_vs_cache = await _get_cache_json(redis_client, "merger_view_session_example_x_foo")
 
     assert merger_vs_res.data == ["x_1", "x_2", "x_3", "x_4", "x_5", "x_6", "x_7", "x_8", "x_9", "x_10"]
     assert len(merger_vs_cache) == merger_vs.session_size
@@ -88,7 +83,7 @@ async def test_merger_view_session_next_page(redis_client) -> None:
     """
     Test fetching the next page of data from the merger with caching.
     """
-    merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_CONFIG)
+    merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_CONFIG)
     merger_vs_res = await merger_vs.get_data(
         methods_dict=METHODS_DICT,
         limit=10,
@@ -98,12 +93,7 @@
         user_id="x",
         redis_client=redis_client,
    )
-    merger_vs_cache = redis_client.get(name="merger_view_session_example_x")
-    # To support the sync and async fixtures in one test, check what get() returned.
-    if inspect.iscoroutine(merger_vs_cache):
-        merger_vs_cache = json.loads(await merger_vs_cache)
-    else:
-        merger_vs_cache = json.loads(merger_vs_cache)
+    merger_vs_cache = await _get_cache_json(redis_client, "merger_view_session_example_x")
 
     assert merger_vs_res.data == ["x_11", "x_12", "x_13", "x_14", "x_15", "x_16", "x_17", "x_18", "x_19", "x_20"]
     assert len(merger_vs_cache) == merger_vs.session_size
@@ -113,7 +103,7 @@
 @pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True)
 @pytest.mark.asyncio
 async def test_merger_view_session_deduplication(redis_client) -> None:
-    merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_DUPS_CONFIG)
+    merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_DUPS_CONFIG)
     merger_vs_res = await merger_vs.get_data(
         methods_dict=METHODS_DICT,
         limit=10,
@@ -121,12 +111,7 @@
         user_id="x",
         redis_client=redis_client,
     )
-    merger_vs_cache = redis_client.get(name="merger_view_session_example_x")
-    # To support the sync and async fixtures in one test, check what get() returned.
-    if inspect.iscoroutine(merger_vs_cache):
-        merger_vs_cache = json.loads(await merger_vs_cache)
-    else:
-        merger_vs_cache = json.loads(merger_vs_cache)
+    merger_vs_cache = await _get_cache_json(redis_client, "merger_view_session_example_x")
 
     assert merger_vs_res.data == [i for i in range(1, 11)]
     assert len(merger_vs_cache) == merger_vs.session_size
diff --git a/tests/test_parsing_config.py b/tests/test_parsing_config.py
index 4d789cb..bf61007 100644
--- a/tests/test_parsing_config.py
+++ b/tests/test_parsing_config.py
@@ -4,6 +4,7 @@
 from smartfeed.schemas import (
     FeedConfig,
     MergerAppend,
+    MergerDeduplication,
     MergerPercentage,
     MergerPercentageGradient,
     MergerPercentageItem,
@@ -11,7 +12,7 @@
     MergerViewSession,
     SubFeed,
 )
-from tests.fixtures.configs import METHODS_DICT, PARSING_CONFIG_FIXTURE
+from tests.fixtures.configs import METHODS_DICT, PARSING_CONFIG_FIXTURE, 
PARSING_DEDUP_CONFIG_FIXTURE
 
 
 @pytest.mark.asyncio
@@ -45,3 +46,13 @@
     # SubFeed with Raise Exception False.
     assert isinstance(feed_manager.feed_config.feed.default.items[0].data, SubFeed)
     assert feed_manager.feed_config.feed.default.items[0].data.raise_error is False
+
+
+@pytest.mark.asyncio
+async def test_parsing_config_deduplication_merger() -> None:
+    feed_manager = FeedManager(config=PARSING_DEDUP_CONFIG_FIXTURE, methods_dict=METHODS_DICT)
+
+    assert isinstance(feed_manager.feed_config, FeedConfig)
+    assert isinstance(feed_manager.feed_config.feed, MergerDeduplication)
+    # Deduplication merger is a wrapper around a single child feed.
+    assert isinstance(feed_manager.feed_config.feed.data, (MergerPercentage, SubFeed))
diff --git a/tests/test_redis_live.py b/tests/test_redis_live.py
index 1a23839..e1dc13d 100644
--- a/tests/test_redis_live.py
+++ b/tests/test_redis_live.py
@@ -1,34 +1,37 @@
 import asyncio
 import json
+import time
+
 import pytest
 import redis
 from redis.asyncio import Redis as AsyncRedis
-import time
 
 from smartfeed.schemas import FeedResultNextPage, MergerViewSession
 from tests.fixtures.configs import METHODS_DICT
 from tests.fixtures.mergers import MERGER_VIEW_SESSION_CONFIG
+from tests.utils import parse_model
 
 
 class RedisReplicationSimulator:
     """
     Simulates Redis replication lag to test the cluster problem.
     """
+
     def __init__(self, real_client):
         self.real_client = real_client
         self.write_delay = 0.1  # Delay to imitate replication lag
         self.pending_writes = {}  # Keys that were just written
-    
+
     def exists(self, cache_key):
         return self.real_client.exists(cache_key)
-    
+
     def set(self, name, value, ex=None):
         # Write to the real Redis
         result = self.real_client.set(name, value, ex=ex)
         # Mark this key as just written (imitating replication lag)
         self.pending_writes[name] = time.time()
         return result
-    
+
     def get(self, name):
         # If the key was just written (within write_delay seconds), return None
         if name in self.pending_writes:
@@ -38,7 +41,7 @@ def get(self, name):
             else:
                 # The delay has passed; the key can be removed from pending
                 del self.pending_writes[name]
-    
+
         # Regular read from Redis
         return self.real_client.get(name)
 
@@ -49,24 +52,24 @@
 async def test_redis_replication_delay_problem():
     """
    Test that reproduces the Redis replication problem, using RedisReplicationSimulator to imitate the delay.
     """
-    
+
     # Connect to Redis (it must be running locally)
     try:
-        real_client = redis.Redis(host='localhost', port=6379, db=0)
+        real_client = redis.Redis(host="localhost", port=6379, db=0)
         real_client.ping()  # Check the connection
     except (redis.ConnectionError, redis.ResponseError):
         pytest.skip("Redis not available for live testing")
-    
+
     # Clear the test key
     test_key = "test_merger_view_session_test_user"
     real_client.delete(test_key)
-    
+
     # Use the replication-delay simulator
     redis_client = RedisReplicationSimulator(real_client)
-    merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_CONFIG)
-    
+    merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_CONFIG)
+
     print("\n=== Demonstrating the replication-delay problem ===")
-    
+
     try:
         # This call should reproduce the problem with the original code
         result = await merger_vs.get_data(
@@ -76,17 +79,17 @@ async def test_redis_replication_delay_problem():
             user_id="test_user",
             redis_client=redis_client,
         )
-    
+
         print("✅ The fix works! Got a result without an error:")
         print(f"    Data: {result.data[:5]}... 
(first 5 shown)")
         print(f"    Size: {len(result.data)}")
         print(f"    Has next page: {result.has_next_page}")
-    
+
         # Check that we received valid data
         assert len(result.data) == 10
         assert result.data[0] == "test_user_1"
         assert result.has_next_page is True
-    
+
     except TypeError as e:
         if "the JSON object must be str, bytes or bytearray, not NoneType" in str(e):
             print("❌ The problem is NOT fixed! Still getting a TypeError")
@@ -94,33 +97,33 @@
         else:
             print(f"❓ Unexpected error: {e}")
             raise
-    
+
     finally:
         # Cleanup
         real_client.delete(test_key)
         real_client.close()
 
 
-@pytest.mark.asyncio 
+@pytest.mark.asyncio
 async def test_redis_multiple_requests():
     """
     Test multiple requests to verify the stability of the fix.
     """
-    
+
     try:
-        real_client = redis.Redis(host='localhost', port=6379, db=0)
+        real_client = redis.Redis(host="localhost", port=6379, db=0)
         real_client.ping()
     except (redis.ConnectionError, redis.ResponseError):
         pytest.skip("Redis not available for live testing")
-    
+
     test_key = "test_merger_multiple_test_user"
     real_client.delete(test_key)
-    
+
     redis_client = RedisReplicationSimulator(real_client)
-    merger_vs = MergerViewSession.parse_obj(MERGER_VIEW_SESSION_CONFIG)
-    
+    merger_vs = parse_model(MergerViewSession, MERGER_VIEW_SESSION_CONFIG)
+
     print("\n=== Multiple requests test ===")
-    
+
     try:
         # First request - creates the cache
         result1 = await merger_vs.get_data(
@@ -130,33 +133,34 @@
             user_id="test_user",
             redis_client=redis_client,
         )
-    
+
         print(f"First request: got {len(result1.data)} items")
-    
+
         # Wait for the replication delay to pass
         await asyncio.sleep(0.2)
-    
-        # Second request - should use the cache
+
+        # Second request - should use the cache
         from smartfeed.schemas import FeedResultNextPageInside
+
         result2 = await merger_vs.get_data(
             methods_dict=METHODS_DICT,
-            limit=5, 
+            limit=5,
             next_page=FeedResultNextPage(
                 data={"merger_view_session_example": FeedResultNextPageInside(page=2, after=None)}
             ),
             user_id="test_user",
             redis_client=redis_client,
         )
-    
+
         print(f"Second request: got {len(result2.data)} items")
         print(f"Second page data: {result2.data}")
-    
+
         # Check that we got different data (pagination works)
         assert result1.data != result2.data
         assert len(result2.data) == 5
-    
+
         print("✅ Multiple requests work correctly!")
-    
+
     finally:
         real_client.delete(test_key)
         real_client.close()
@@ -164,4 +168,4 @@
 
 if __name__ == "__main__":
     # For running directly without pytest
-    asyncio.run(test_redis_replication_delay_problem())
\ No newline at end of file
+    asyncio.run(test_redis_replication_delay_problem())
diff --git a/tests/test_seen_store_unit.py b/tests/test_seen_store_unit.py
new file mode 100644
index 0000000..fc7f6e7
--- /dev/null
+++ b/tests/test_seen_store_unit.py
@@ -0,0 +1,65 @@
+from typing import Any
+
+import pytest
+
+from smartfeed.feed_models import _redis_call
+from smartfeed.policies.dedup_utils import decode_seen_from_cursor
+from smartfeed.policies.seen_store import CursorSeenStore, RedisSeenStore
+from tests.fixtures.redis import redis_client
+
+
+@pytest.mark.asyncio
+async def test_cursor_seen_store_set_max_and_commit_roundtrip() -> None:
+    store = CursorSeenStore.from_after(after=None, cursor_compress=True, cursor_max_keys=None)
+    store.set_max("a", 1)
+    store.set_max("a", 1)  # no-op
+    store.set_max("a", 0)  # no-op (lower)
+    store.set_max("b", 2)
+
+    after = await store.commit()
+    decoded = 
diff --git a/tests/test_seen_store_unit.py b/tests/test_seen_store_unit.py
new file mode 100644
index 0000000..fc7f6e7
--- /dev/null
+++ b/tests/test_seen_store_unit.py
@@ -0,0 +1,65 @@
+from typing import Any
+
+import pytest
+
+from smartfeed.feed_models import _redis_call
+from smartfeed.policies.dedup_utils import decode_seen_from_cursor
+from smartfeed.policies.seen_store import CursorSeenStore, RedisSeenStore
+from tests.fixtures.redis import redis_client
+
+
+@pytest.mark.asyncio
+async def test_cursor_seen_store_set_max_and_commit_roundtrip() -> None:
+    store = CursorSeenStore.from_after(after=None, cursor_compress=True, cursor_max_keys=None)
+    store.set_max("a", 1)
+    store.set_max("a", 1)  # no-op
+    store.set_max("a", 0)  # no-op (lower)
+    store.set_max("b", 2)
+
+    after = await store.commit()
+    decoded = decode_seen_from_cursor(after)
+    assert decoded == {"a": 1, "b": 2}
+
+
+@pytest.mark.asyncio
+async def test_cursor_seen_store_commit_keeps_previous_cursor_state() -> None:
+    store = CursorSeenStore.from_after(
+        after={"v": 2, "seen": [["a", 1], ["b", 2]]},
+        cursor_compress=False,
+        cursor_max_keys=None,
+    )
+    store.set_max("c", 3)
+
+    after = await store.commit()
+    decoded = decode_seen_from_cursor(after)
+    assert decoded == {"a": 1, "b": 2, "c": 3}
+
+
+@pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True)
+@pytest.mark.asyncio
+async def test_redis_seen_store_prefetch_set_max_commit_and_reset(redis_client) -> None:
+    key = "test_seen_store"
+    await _redis_call(redis_client, "delete", key)
+    # Pre-seed zset state
+    await _redis_call(redis_client, "zadd", key, mapping={"a": 5.0})
+
+    store = RedisSeenStore.create(redis_client=redis_client, redis_key=key, ttl_seconds=60)
+
+    await store.prefetch(["a", "a", "b"])  # duplicates on purpose
+    assert store.get("a") == 5
+    assert store.get("b") is None
+
+    store.set_max("a", 3)  # should not reduce the existing value
+    store.set_max("b", 2)
+
+    await store.commit()
+
+    # The new state should be present in Redis
+    scores = list(await _redis_call(redis_client, "zmscore", key, ["a", "b"]))
+    assert scores == [5.0, 2.0]
+
+    await store.reset()
+    scores_after_reset = list(await _redis_call(redis_client, "zmscore", key, ["a", "b"]))
+    assert scores_after_reset == [None, None]
+
+    await _redis_call(redis_client, "delete", key)
diff --git a/tests/test_sub_feed.py b/tests/test_sub_feed.py
index 7da1924..11645d4 100644
--- a/tests/test_sub_feed.py
+++ b/tests/test_sub_feed.py
@@ -16,7 +16,7 @@
 async def test_sub_feed() -> None:
     """
     Test fetching data from a subfeed (without params).
     """
-    sub_feed = SubFeed.parse_obj(SUBFEED_CONFIG)
+    sub_feed = SubFeed.model_validate(SUBFEED_CONFIG)
     sub_feed_data = await sub_feed.get_data(
         methods_dict=METHODS_DICT,
         limit=15,
@@ -33,7 +33,7 @@
 async def test_sub_feed_with_params() -> None:
     """
     Test fetching data from a subfeed (with params).
     """
-    sub_feed = SubFeed.parse_obj(SUBFEED_WITH_PARAMS_CONFIG)
+    sub_feed = SubFeed.model_validate(SUBFEED_WITH_PARAMS_CONFIG)
     sub_feed_data = await sub_feed.get_data(
         methods_dict=METHODS_DICT,
         limit=15,
@@ -50,7 +50,7 @@
 async def test_sub_feed_raise_error() -> None:
     """
     Test that an error from the subfeed method propagates (raise_error enabled).
     """
-    sub_feed = SubFeed.parse_obj(SUBFEED_CONFIG_RAISE_ERROR)
+    sub_feed = SubFeed.model_validate(SUBFEED_CONFIG_RAISE_ERROR)
 
     with pytest.raises(Exception):
         await sub_feed.get_data(
@@ -67,7 +67,7 @@
 async def test_sub_feed_no_raise_error() -> None:
     """
     Test that a subfeed method failure yields an empty result when raise_error is False.
""" - sub_feed = SubFeed.parse_obj(SUBFEED_CONFIG_NO_RAISE_ERROR) + sub_feed = SubFeed.model_validate(SUBFEED_CONFIG_NO_RAISE_ERROR) sub_feed_data = await sub_feed.get_data( methods_dict=METHODS_DICT, limit=15, diff --git a/tests/test_view_session_unit.py b/tests/test_view_session_unit.py new file mode 100644 index 0000000..a847eea --- /dev/null +++ b/tests/test_view_session_unit.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass +from typing import Any + +import pytest + +from smartfeed.feed_models import _redis_call +from smartfeed.schemas import FeedResultNextPage, FeedResultNextPageInside, MergerViewSession +from tests.fixtures.configs import METHODS_DICT +from tests.fixtures.mergers import MERGER_VIEW_SESSION_CONFIG +from tests.fixtures.redis import redis_client +from tests.utils import parse_model + + +@dataclass +class _ItemWithAttr: + id: str + + +def test_get_dedup_key_supports_dict_and_attr_and_raises_on_missing() -> None: + cfg = dict(MERGER_VIEW_SESSION_CONFIG) + cfg.update({"deduplicate": True, "dedup_key": "id"}) + merger = parse_model(MergerViewSession, cfg) + + assert merger._get_dedup_key_or_attr({"id": "x"}) == "x" + assert merger._get_dedup_key_or_attr(_ItemWithAttr(id="y")) == "y" + + with pytest.raises(AssertionError): + merger._get_dedup_key_or_attr({"nope": 1}) + + +@pytest.mark.parametrize("redis_client", ["sync", "async"], indirect=True) +@pytest.mark.asyncio +async def test_view_session_shuffle_applies_to_result(redis_client, monkeypatch) -> None: + import smartfeed.mergers.view_session as vs_mod + + cfg = dict(MERGER_VIEW_SESSION_CONFIG) + cfg.update({"shuffle": True}) + merger = parse_model(MergerViewSession, cfg) + cache_key = f"{merger.merger_id}_x" + await _redis_call(redis_client, "delete", cache_key) + + # Make shuffle deterministic: reverse in-place + monkeypatch.setattr(vs_mod, "shuffle", lambda data: data.reverse()) + + res = await merger.get_data( + methods_dict=METHODS_DICT, + limit=5, + next_page=FeedResultNextPage(data={}), + user_id="x", + redis_client=redis_client, + ) + + assert res.data == ["x_5", "x_4", "x_3", "x_2", "x_1"] + await _redis_call(redis_client, "delete", cache_key) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..af29848 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from smartfeed.pydantic_compat import parse_model + +__all__ = ["parse_model"]