Skip to content

Commit f6e109e

Browse files
grichaclaude
andcommitted
ref(seer): Add tests for _resolve_viewer_context
Address review feedback: add tests covering all branches of the viewer context resolution (contextvar only, explicit only, merge, mismatch warning, token stripping). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ac653dc commit f6e109e

File tree

3 files changed

+153
-38
lines changed

3 files changed

+153
-38
lines changed

src/sentry/seer/signed_seer_api.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from sentry.net.http import connection_from_url
1313
from sentry.utils import metrics
14-
from sentry.viewer_context import get_viewer_context
14+
from sentry.viewer_context import ViewerContext, get_viewer_context
1515

1616

1717
class SeerViewerContext(TypedDict, total=False):
@@ -49,43 +49,65 @@ class SeerViewerContext(TypedDict, total=False):
4949

5050

5151
def _resolve_viewer_context(
52-
explicit: SeerViewerContext | None,
53-
) -> dict[str, Any] | None:
54-
"""Build the viewer context payload for Seer requests.
52+
explicit: SeerViewerContext | None = None,
53+
) -> ViewerContext | None:
54+
"""Merge explicit SeerViewerContext with the contextvar.
5555
56-
Uses the contextvar as the base. If an explicit SeerViewerContext was
57-
passed, its non-None fields win (with a warning on disagreement).
56+
Converts the legacy SeerViewerContext into a ViewerContext, then merges
57+
with the contextvar. Explicit non-None fields win. On disagreement,
58+
logs a warning and strips the token for safety.
5859
"""
5960
vc = get_viewer_context()
60-
if vc is None and explicit is None:
61+
62+
if explicit is None and vc is None:
6163
return None
64+
if explicit is None:
65+
return vc
66+
67+
explicit_vc = ViewerContext(
68+
organization_id=explicit.get("organization_id"),
69+
user_id=explicit.get("user_id"),
70+
)
6271

63-
result: dict[str, Any] = {}
64-
if vc is not None:
65-
if vc.organization_id is not None:
66-
result["organization_id"] = vc.organization_id
67-
if vc.user_id is not None:
68-
result["user_id"] = vc.user_id
69-
result["actor_type"] = vc.actor_type.value
70-
if vc.token is not None:
71-
result["token"] = {"kind": vc.token.kind, "scopes": list(vc.token.get_scopes())}
72-
73-
if explicit:
74-
has_mismatch = False
75-
for field in ("organization_id", "user_id"):
76-
val = explicit.get(field) # type: ignore[literal-required]
77-
if val is not None and field in result and result[field] != val:
78-
logger.warning(
79-
"seer.viewer_context_mismatch",
80-
extra={"field": field, "contextvar": result[field], "explicit": val},
81-
)
82-
has_mismatch = True
83-
if val is not None:
84-
result[field] = val
85-
if has_mismatch:
86-
result.pop("token", None)
87-
88-
return result or None
72+
if vc is None:
73+
return explicit_vc
74+
75+
has_mismatch = False
76+
org_id = vc.organization_id
77+
user_id = vc.user_id
78+
79+
if explicit_vc.organization_id is not None:
80+
if org_id is not None and org_id != explicit_vc.organization_id:
81+
logger.warning(
82+
"seer.viewer_context_mismatch",
83+
extra={
84+
"field": "organization_id",
85+
"contextvar": org_id,
86+
"explicit": explicit_vc.organization_id,
87+
},
88+
)
89+
has_mismatch = True
90+
org_id = explicit_vc.organization_id
91+
92+
if explicit_vc.user_id is not None:
93+
if user_id is not None and user_id != explicit_vc.user_id:
94+
logger.warning(
95+
"seer.viewer_context_mismatch",
96+
extra={
97+
"field": "user_id",
98+
"contextvar": user_id,
99+
"explicit": explicit_vc.user_id,
100+
},
101+
)
102+
has_mismatch = True
103+
user_id = explicit_vc.user_id
104+
105+
return ViewerContext(
106+
organization_id=org_id,
107+
user_id=user_id,
108+
actor_type=vc.actor_type,
109+
token=None if has_mismatch else vc.token,
110+
)
89111

90112

91113
@sentry_sdk.tracing.trace
@@ -113,11 +135,11 @@ def make_signed_seer_api_request(
113135
**auth_headers,
114136
}
115137

116-
resolved_context = _resolve_viewer_context(viewer_context)
117-
if resolved_context:
138+
resolved = _resolve_viewer_context(viewer_context)
139+
if resolved:
118140
if settings.SEER_API_SHARED_SECRET:
119141
try:
120-
context_bytes = orjson.dumps(resolved_context)
142+
context_bytes = orjson.dumps(resolved.serialize())
121143
context_signature = sign_viewer_context(context_bytes)
122144
headers["X-Viewer-Context"] = context_bytes.decode("utf-8")
123145
headers["X-Viewer-Context-Signature"] = context_signature

src/sentry/viewer_context.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import dataclasses
66
import enum
77
from collections.abc import Generator
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
99

1010
if TYPE_CHECKING:
1111
from sentry.auth.services.auth import AuthenticatedToken
@@ -51,6 +51,17 @@ class ViewerContext:
5151
# NOT propagated across process/service boundaries.
5252
token: AuthenticatedToken | None = None
5353

54+
def serialize(self) -> dict[str, Any]:
55+
"""Serialize to a dict for cross-service headers (e.g. X-Viewer-Context)."""
56+
result: dict[str, Any] = {"actor_type": self.actor_type.value}
57+
if self.organization_id is not None:
58+
result["organization_id"] = self.organization_id
59+
if self.user_id is not None:
60+
result["user_id"] = self.user_id
61+
if self.token is not None:
62+
result["token"] = {"kind": self.token.kind, "scopes": list(self.token.get_scopes())}
63+
return result
64+
5465

5566
@contextlib.contextmanager
5667
def viewer_context_scope(ctx: ViewerContext) -> Generator[None]:

tests/sentry/seer/test_signed_seer_api.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44
from django.test import override_settings
55

6-
from sentry.seer.signed_seer_api import make_signed_seer_api_request
6+
from sentry.auth.services.auth import AuthenticatedToken
7+
from sentry.seer.signed_seer_api import (
8+
SeerViewerContext,
9+
_resolve_viewer_context,
10+
make_signed_seer_api_request,
11+
)
12+
from sentry.viewer_context import ActorType, ViewerContext, viewer_context_scope
713

814
REQUEST_BODY = b'{"b": 12, "thing": "thing"}'
915
PATH = "/v0/some/url"
@@ -111,3 +117,79 @@ def test_times_request(mock_metrics_timer: MagicMock, path: str) -> None:
111117
"endpoint": PATH,
112118
},
113119
)
120+
121+
122+
class TestResolveViewerContext:
123+
def test_both_none(self) -> None:
124+
assert _resolve_viewer_context(None) is None
125+
126+
def test_contextvar_only(self) -> None:
127+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER)
128+
with viewer_context_scope(ctx):
129+
result = _resolve_viewer_context(None)
130+
131+
assert result is not None
132+
assert result.organization_id == 42
133+
assert result.user_id == 7
134+
assert result.actor_type == ActorType.USER
135+
136+
def test_explicit_only(self) -> None:
137+
result = _resolve_viewer_context(SeerViewerContext(organization_id=99, user_id=5))
138+
assert result is not None
139+
assert result.organization_id == 99
140+
assert result.user_id == 5
141+
142+
def test_contextvar_with_token(self) -> None:
143+
token = AuthenticatedToken(
144+
kind="api_token",
145+
scopes=["org:read", "project:write"],
146+
allowed_origins=[],
147+
)
148+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER, token=token)
149+
with viewer_context_scope(ctx):
150+
result = _resolve_viewer_context(None)
151+
152+
assert result is not None
153+
assert result.token is not None
154+
assert result.token.kind == "api_token"
155+
assert set(result.token.get_scopes()) == {"org:read", "project:write"}
156+
157+
def test_explicit_overrides_contextvar(self) -> None:
158+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER)
159+
with viewer_context_scope(ctx):
160+
result = _resolve_viewer_context(SeerViewerContext(organization_id=42, user_id=99))
161+
162+
assert result is not None
163+
assert result.organization_id == 42
164+
assert result.user_id == 99
165+
assert result.actor_type == ActorType.USER
166+
167+
@patch("sentry.seer.signed_seer_api.logger")
168+
def test_mismatch_warns_and_strips_token(self, mock_logger: MagicMock) -> None:
169+
token = AuthenticatedToken(
170+
kind="api_token",
171+
scopes=["org:read"],
172+
allowed_origins=[],
173+
)
174+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER, token=token)
175+
with viewer_context_scope(ctx):
176+
result = _resolve_viewer_context(SeerViewerContext(organization_id=999))
177+
178+
assert result is not None
179+
assert result.organization_id == 999
180+
assert result.token is None
181+
mock_logger.warning.assert_called_once()
182+
assert mock_logger.warning.call_args[0][0] == "seer.viewer_context_mismatch"
183+
184+
def test_no_mismatch_keeps_token(self) -> None:
185+
token = AuthenticatedToken(
186+
kind="api_token",
187+
scopes=["org:read"],
188+
allowed_origins=[],
189+
)
190+
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER, token=token)
191+
with viewer_context_scope(ctx):
192+
result = _resolve_viewer_context(SeerViewerContext(organization_id=42, user_id=7))
193+
194+
assert result is not None
195+
assert result.token is not None

0 commit comments

Comments
 (0)