Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tools/flake8_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@

S016_msg = "S016 Use `from sentry.utils.concurrent import ContextPropagatingThreadPoolExecutor` instead of `concurrent.futures.ThreadPoolExecutor` to ensure contextvars propagation."

S017_msg = (
"S017 Platform boundary violation: do not import non-platform getsentry code in "
"billing/platform/. Use only getsentry.billing.platform.* imports."
)


# --- S015: do not hardcode current or future UTC year as test "now" ---
# Flag year >= current UTC year at lint time. Module/class scope + freeze_time(datetime(...)).
Expand All @@ -57,6 +62,19 @@ def _is_tests_path(filename: str) -> bool:
return "tests/" in filename or "testutils/" in filename


def _is_platform_path(filename: str) -> bool:
Comment thread
noahsmartin marked this conversation as resolved.
return "billing/platform/" in filename and "tests/" not in filename


def _is_non_platform_import(module: str) -> bool:
"""Check if a getsentry import is outside the billing platform."""
if module.startswith("getsentry.") or module == "getsentry":
platform_prefix = "getsentry.billing.platform"
if not (module.startswith(platform_prefix + ".") or module == platform_prefix):
return True
return False


# Returns the literal year when this is a datetime(...) call shape we lint for.
def _wall_clock_year_from_datetime_call(node: ast.Call) -> int | None:
if not node.args:
Expand Down Expand Up @@ -120,6 +138,14 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
):
self.errors.append((node.lineno, node.col_offset, S016_msg))

if (
_is_platform_path(self.filename)
and node.module
and not node.level
and _is_non_platform_import(node.module)
):
self.errors.append((node.lineno, node.col_offset, S017_msg))

self.generic_visit(node)

def visit_Import(self, node: ast.Import) -> None:
Expand All @@ -134,6 +160,11 @@ def visit_Import(self, node: ast.Import) -> None:
):
self.errors.append((node.lineno, node.col_offset, S007_msg))

if _is_platform_path(self.filename):
for alias in node.names:
if _is_non_platform_import(alias.name):
self.errors.append((node.lineno, node.col_offset, S017_msg))

self.generic_visit(node)

def visit_Attribute(self, node: ast.Attribute) -> None:
Expand Down
Loading