Skip to content

Commit 0b25b89

Browse files
grichaclaude
andcommitted
ref(seer): Type return as ResolvedViewerContext and add tests
Address review feedback: use TypedDict for the resolved context wire format instead of dict[str, Any], and add tests covering all branches of _resolve_viewer_context. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ac653dc commit 0b25b89

File tree

2 files changed

+102
-5
lines changed

2 files changed

+102
-5
lines changed

src/sentry/seer/signed_seer_api.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ class SeerViewerContext(TypedDict, total=False):
2424
user_id: int | None
2525

2626

27+
class _SeerTokenPayload(TypedDict):
28+
kind: str
29+
scopes: list[str]
30+
31+
32+
class ResolvedViewerContext(TypedDict, total=False):
33+
"""Wire format for the X-Viewer-Context header sent to Seer."""
34+
35+
organization_id: int
36+
user_id: int | None
37+
actor_type: str
38+
token: _SeerTokenPayload
39+
40+
2741
logger = logging.getLogger(__name__)
2842

2943

@@ -50,7 +64,7 @@ class SeerViewerContext(TypedDict, total=False):
5064

5165
def _resolve_viewer_context(
5266
explicit: SeerViewerContext | None,
53-
) -> dict[str, Any] | None:
67+
) -> ResolvedViewerContext | None:
5468
"""Build the viewer context payload for Seer requests.
5569
5670
Uses the contextvar as the base. If an explicit SeerViewerContext was
@@ -60,15 +74,17 @@ def _resolve_viewer_context(
6074
if vc is None and explicit is None:
6175
return None
6276

63-
result: dict[str, Any] = {}
77+
result = ResolvedViewerContext()
6478
if vc is not None:
6579
if vc.organization_id is not None:
6680
result["organization_id"] = vc.organization_id
6781
if vc.user_id is not None:
6882
result["user_id"] = vc.user_id
6983
result["actor_type"] = vc.actor_type.value
7084
if vc.token is not None:
71-
result["token"] = {"kind": vc.token.kind, "scopes": list(vc.token.get_scopes())}
85+
result["token"] = _SeerTokenPayload(
86+
kind=vc.token.kind, scopes=list(vc.token.get_scopes())
87+
)
7288

7389
if explicit:
7490
has_mismatch = False
@@ -81,7 +97,7 @@ def _resolve_viewer_context(
8197
)
8298
has_mismatch = True
8399
if val is not None:
84-
result[field] = val
100+
result[field] = val # type: ignore[literal-required]
85101
if has_mismatch:
86102
result.pop("token", None)
87103

tests/sentry/seer/test_signed_seer_api.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44
from django.test import override_settings
55

6-
from sentry.seer.signed_seer_api import make_signed_seer_api_request
6+
from sentry.auth.services.auth import AuthenticatedToken
7+
from sentry.seer.signed_seer_api import (
8+
SeerViewerContext,
9+
_resolve_viewer_context,
10+
make_signed_seer_api_request,
11+
)
12+
from sentry.viewer_context import ActorType, ViewerContext, viewer_context_scope
713

814
REQUEST_BODY = b'{"b": 12, "thing": "thing"}'
915
PATH = "/v0/some/url"
@@ -111,3 +117,78 @@ def test_times_request(mock_metrics_timer: MagicMock, path: str) -> None:
111117
"endpoint": PATH,
112118
},
113119
)
120+
121+
122+
class TestResolveViewerContext:
123+
def test_both_none(self) -> None:
124+
assert _resolve_viewer_context(None) is None
125+
126+
def test_contextvar_only(self) -> None:
127+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER)
128+
with viewer_context_scope(ctx):
129+
result = _resolve_viewer_context(None)
130+
131+
assert result is not None
132+
assert result["organization_id"] == 42
133+
assert result["user_id"] == 7
134+
assert result["actor_type"] == "user"
135+
136+
def test_explicit_only(self) -> None:
137+
result = _resolve_viewer_context(SeerViewerContext(organization_id=99, user_id=5))
138+
assert result is not None
139+
assert result["organization_id"] == 99
140+
assert result["user_id"] == 5
141+
142+
def test_contextvar_with_token(self) -> None:
143+
token = AuthenticatedToken(
144+
kind="api_token",
145+
scopes=["org:read", "project:write"],
146+
allowed_origins=[],
147+
)
148+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER, token=token)
149+
with viewer_context_scope(ctx):
150+
result = _resolve_viewer_context(None)
151+
152+
assert result is not None
153+
assert result["token"]["kind"] == "api_token"
154+
assert set(result["token"]["scopes"]) == {"org:read", "project:write"}
155+
156+
def test_explicit_overrides_contextvar(self) -> None:
157+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER)
158+
with viewer_context_scope(ctx):
159+
result = _resolve_viewer_context(SeerViewerContext(organization_id=42, user_id=99))
160+
161+
assert result is not None
162+
assert result["organization_id"] == 42
163+
assert result["user_id"] == 99
164+
assert result["actor_type"] == "user"
165+
166+
@patch("sentry.seer.signed_seer_api.logger")
167+
def test_mismatch_warns_and_strips_token(self, mock_logger: MagicMock) -> None:
168+
token = AuthenticatedToken(
169+
kind="api_token",
170+
scopes=["org:read"],
171+
allowed_origins=[],
172+
)
173+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER, token=token)
174+
with viewer_context_scope(ctx):
175+
result = _resolve_viewer_context(SeerViewerContext(organization_id=999))
176+
177+
assert result is not None
178+
assert result["organization_id"] == 999
179+
assert "token" not in result
180+
mock_logger.warning.assert_called_once()
181+
assert mock_logger.warning.call_args[0][0] == "seer.viewer_context_mismatch"
182+
183+
def test_no_mismatch_keeps_token(self) -> None:
184+
token = AuthenticatedToken(
185+
kind="api_token",
186+
scopes=["org:read"],
187+
allowed_origins=[],
188+
)
189+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER, token=token)
190+
with viewer_context_scope(ctx):
191+
result = _resolve_viewer_context(SeerViewerContext(organization_id=42, user_id=7))
192+
193+
assert result is not None
194+
assert "token" in result

0 commit comments

Comments
 (0)