|
5 | 5 | T = TypeVar("T") |
6 | 6 |
|
7 | 7 |
|
| 8 | +class SortField(BaseModel): |
| 9 | + field: str |
| 10 | + order: str = "asc" |
| 11 | + |
| 12 | + |
8 | 13 | class BaseFilter(BaseModel): |
9 | 14 | page: int = Field(1, ge=1) |
10 | 15 | size: int = Field(20, ge=1, le=100) |
11 | | - sort_by: str = "id" |
12 | | - sort_order: str = Field("asc", pattern="^(asc|desc)$") |
| 16 | + sorts: list[SortField] = Field(default=[SortField(field="id")]) |
13 | 17 |
|
14 | 18 | search_fields: list[str] = Field(default=[], exclude=True) |
15 | 19 |
|
16 | 20 | def filters(self) -> dict: |
17 | | - excluded = {"page", "size", "sort_by", "sort_order", "search_fields"} |
| 21 | + excluded = {"page", "size", "sorts", "search_fields"} |
18 | 22 | pairs = self.model_dump(exclude=excluded, exclude_none=True) |
19 | 23 | return pairs |
20 | 24 |
|
@@ -52,10 +56,14 @@ def _valid_columns(rows: list[dict]) -> set[str]: |
52 | 56 |
|
53 | 57 |
|
54 | 58 | def _sort_rows(rows: list[dict], filt: BaseFilter) -> list[dict]: |
55 | | - if filt.sort_by not in _valid_columns(rows): |
| 59 | + valid = _valid_columns(rows) |
| 60 | + safe = [s for s in filt.sorts if s.field in valid] |
| 61 | + if not safe: |
56 | 62 | return rows |
57 | | - reverse = filt.sort_order == "desc" |
58 | | - return sorted(rows, key=lambda r: r.get(filt.sort_by, ""), reverse=reverse) |
| 63 | + for s in reversed(safe): |
| 64 | + reverse = s.order == "desc" |
| 65 | + rows = sorted(rows, key=lambda r, f=s.field: r.get(f, ""), reverse=reverse) |
| 66 | + return rows |
59 | 67 |
|
60 | 68 |
|
61 | 69 | def apply_filter(rows: list[dict], filt: BaseFilter) -> PagedResponse: |
|
0 commit comments