diff --git a/ctms/auth.py b/ctms/auth.py index 919859f9..0e5902f0 100644 --- a/ctms/auth.py +++ b/ctms/auth.py @@ -7,10 +7,8 @@ the client POSTs to /token again. """ -import warnings from contextvars import ContextVar -from datetime import datetime, timedelta, timezone -from typing import Dict, Optional +from datetime import UTC, datetime, timedelta import argon2 import jwt @@ -24,7 +22,8 @@ pwd_context = argon2.PasswordHasher() -auth_info_context: ContextVar[dict] = ContextVar("auth_info_context", default={}) +auth_info_context: ContextVar[dict] = ContextVar("auth_info_context") +auth_info_context.set({}) def verify_password(plain_password, hashed_password) -> bool: @@ -42,11 +41,11 @@ def create_access_token( data: dict, expires_delta: timedelta, secret_key: str, - now: Optional[datetime] = None, + now: datetime | None = None, ) -> str: """Create a JWT string to act as an OAuth2 access token.""" to_encode = data.copy() - expire = (now or datetime.now(timezone.utc)) + expires_delta + expire = (now or datetime.now(UTC)) + expires_delta to_encode["exp"] = expire encoded_jwt: str = jwt.encode(to_encode, secret_key, algorithm="HS256") return encoded_jwt @@ -97,8 +96,8 @@ def __init__( self, grant_type: str = Form(None, pattern="^(client_credentials|refresh_token)$"), scope: str = Form(""), - client_id: Optional[str] = Form(None), - client_secret: Optional[str] = Form(None), + client_id: str | None = Form(None), + client_secret: str | None = Form(None), ): self.grant_type = grant_type self.scopes = scope.split() @@ -121,18 +120,16 @@ class OAuth2ClientCredentials(OAuth2): def __init__( self, tokenUrl: str, - scheme_name: Optional[str] = None, - scopes: Optional[Dict[str, str]] = None, + scheme_name: str | None = None, + scopes: dict[str, str] | None = None, ): if not scopes: scopes = {} - flows = OAuthFlowsModel( - clientCredentials={"tokenUrl": tokenUrl, "scopes": scopes} - ) + flows = OAuthFlowsModel(clientCredentials={"tokenUrl": tokenUrl, "scopes": scopes}) super().__init__(flows=flows, scheme_name=scheme_name, auto_error=True) - async def __call__(self, request: Request) -> Optional[str]: - authorization: Optional[str] = request.headers.get("Authorization") + async def __call__(self, request: Request) -> str | None: + authorization: str | None = request.headers.get("Authorization") # TODO: Try combining these lines after FastAPI 0.61.2 / mypy update scheme_param = get_authorization_scheme_param(authorization) diff --git a/ctms/bin/client_credentials.py b/ctms/bin/client_credentials.py index 7bb13949..f1d3dbab 100755 --- a/ctms/bin/client_credentials.py +++ b/ctms/bin/client_credentials.py @@ -81,9 +81,7 @@ def print_new_credentials( """ ) else: - print( - "These credentials are currently disabled, and can not be used to get an OAuth2 access token." - ) + print("These credentials are currently disabled, and can not be used to get an OAuth2 access token.") def main(db, settings, test_args=None): # noqa: PLR0912 @@ -99,15 +97,9 @@ def main(db, settings, test_args=None): # noqa: PLR0912 parser = argparse.ArgumentParser(description="Create or update client credentials.") parser.add_argument("name", help="short name of the client") parser.add_argument("-e", "--email", help="contact email for the client") - parser.add_argument( - "--enable", action="store_true", help="enable a disabled client" - ) - parser.add_argument( - "--disable", action="store_true", help="disable a new or enabled client" - ) - parser.add_argument( - "--rotate-secret", action="store_true", help="generate a new secret key" - ) + parser.add_argument("--enable", action="store_true", help="enable a disabled client") + parser.add_argument("--disable", action="store_true", help="disable a new or enabled client") + parser.add_argument("--rotate-secret", action="store_true", help="generate a new secret key") args = parser.parse_args(args=test_args) name = args.name @@ -117,9 +109,7 @@ def main(db, settings, test_args=None): # noqa: PLR0912 rotate = args.rotate_secret if not re.match(r"^[-_.a-zA-Z0-9]*$", name): - print( - f"name '{name}' should have only alphanumeric characters, '-', '_', or '.'" - ) + print(f"name '{name}' should have only alphanumeric characters, '-', '_', or '.'") return 1 if enable and disable: @@ -168,9 +158,7 @@ def main(db, settings, test_args=None): # noqa: PLR0912 enabled = not disable client_id, client_secret = create_client(db, client_id, email, enabled) db.commit() - print_new_credentials( - client_id, client_secret, settings, sample_email=email, enabled=enabled - ) + print_new_credentials(client_id, client_secret, settings, sample_email=email, enabled=enabled) return 0 diff --git a/ctms/config.py b/ctms/config.py index 574c05cc..23af3bd2 100644 --- a/ctms/config.py +++ b/ctms/config.py @@ -3,7 +3,7 @@ from enum import Enum from functools import lru_cache from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated from pydantic import AfterValidator, Field, PostgresDsn from pydantic_settings import BaseSettings, SettingsConfigDict @@ -13,7 +13,7 @@ PostgresDsnStr = Annotated[PostgresDsn, AfterValidator(str)] -@lru_cache() +@lru_cache def get_version(): """ Return contents of version.json. @@ -49,11 +49,11 @@ class Settings(BaseSettings): logging_level: LogLevel = LogLevel.INFO sentry_debug: bool = False - fastapi_env: Optional[str] = Field(default=None, alias="FASTAPI_ENV") - sentry_dsn: Optional[AnyUrlString] = Field(default=None, alias="SENTRY_DSN") + fastapi_env: str | None = Field(default=None, alias="FASTAPI_ENV") + sentry_dsn: AnyUrlString | None = Field(default=None, alias="SENTRY_DSN") host: str = Field(default="0.0.0.0", alias="HOST") port: int = Field(default=8000, alias="PORT") - prometheus_pushgateway_url: Optional[str] = None + prometheus_pushgateway_url: str | None = None model_config = SettingsConfigDict(env_prefix="ctms_") diff --git a/ctms/crud.py b/ctms/crud.py index 9a5382ee..3fdbc1d5 100644 --- a/ctms/crud.py +++ b/ctms/crud.py @@ -2,8 +2,9 @@ import logging import uuid -from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast +from collections.abc import Callable +from datetime import UTC, datetime +from typing import Any, cast from pydantic import UUID4 from sqlalchemy import asc, or_, text @@ -59,18 +60,10 @@ def count_total_contacts(db: Session) -> int: This metadata is refreshed on `VACUUM` or `ANALYSIS` which is run regularly by default on our database instances. """ - result = db.execute( - text( - "SELECT reltuples AS estimate " - "FROM pg_class " - f"where relname = '{Email.__tablename__}'" - ) - ).scalar() + result = db.execute(text("SELECT reltuples AS estimate " "FROM pg_class " f"where relname = '{Email.__tablename__}'")).scalar() if result is None or result < 0: # Fall back to a full count if the estimate is not available. - result = db.execute( - text(f"SELECT COUNT(*) FROM {Email.__tablename__}") - ).scalar() + result = db.execute(text(f"SELECT COUNT(*) FROM {Email.__tablename__}")).scalar() if result is None: return -1 return int(result) @@ -81,19 +74,11 @@ def get_amo_by_email_id(db: Session, email_id: UUID4): def get_fxa_by_email_id(db: Session, email_id: UUID4): - return ( - db.query(FirefoxAccount) - .filter(FirefoxAccount.email_id == email_id) - .one_or_none() - ) + return db.query(FirefoxAccount).filter(FirefoxAccount.email_id == email_id).one_or_none() def get_mofo_by_email_id(db: Session, email_id: UUID4): - return ( - db.query(MozillaFoundationContact) - .filter(MozillaFoundationContact.email_id == email_id) - .one_or_none() - ) + return db.query(MozillaFoundationContact).filter(MozillaFoundationContact.email_id == email_id).one_or_none() def get_newsletters_by_email_id(db: Session, email_id: UUID4): @@ -131,7 +116,7 @@ def get_bulk_query(start_time, end_time, after_email_uuid, mofo_relevant): if mofo_relevant is False: filters.append( or_( - Email.mofo == None, + Email.mofo == None, # noqa: E711 Email.mofo.has(mofo_relevant=mofo_relevant), ) ) @@ -145,8 +130,8 @@ def get_bulk_contacts( start_time: datetime, end_time: datetime, limit: int, - mofo_relevant: Optional[bool] = None, - after_email_id: Optional[str] = None, + mofo_relevant: bool | None = None, + after_email_id: str | None = None, ): """Get all the data for a bulk batched set of contacts.""" after_email_uuid = None @@ -162,24 +147,20 @@ def get_bulk_contacts( for query_filter in filter_list: bulk_contacts = bulk_contacts.filter(query_filter) - bulk_contacts = ( - bulk_contacts.order_by(asc(Email.update_timestamp), asc(Email.email_id)) - .limit(limit) - .all() - ) + bulk_contacts = bulk_contacts.order_by(asc(Email.update_timestamp), asc(Email.email_id)).limit(limit).all() return [ContactSchema.from_email(email) for email in bulk_contacts] -def get_email(db: Session, email_id: UUID4) -> Optional[Email]: +def get_email(db: Session, email_id: UUID4) -> Email | None: """Get an Email and all related data.""" return cast( - Optional[Email], + Email | None, _contact_base_query(db).filter(Email.email_id == email_id).one_or_none(), ) -def get_contact_by_email_id(db: Session, email_id: UUID4) -> Optional[ContactSchema]: +def get_contact_by_email_id(db: Session, email_id: UUID4) -> ContactSchema | None: """Return a Contact object for a given email id""" email = get_email(db, email_id) if email is None: @@ -189,16 +170,16 @@ def get_contact_by_email_id(db: Session, email_id: UUID4) -> Optional[ContactSch def get_contacts_by_any_id( db: Session, - email_id: Optional[UUID4] = None, - primary_email: Optional[str] = None, - basket_token: Optional[UUID4] = None, - sfdc_id: Optional[str] = None, - mofo_contact_id: Optional[str] = None, - mofo_email_id: Optional[str] = None, - amo_user_id: Optional[str] = None, - fxa_id: Optional[str] = None, - fxa_primary_email: Optional[str] = None, -) -> List[ContactSchema]: + email_id: UUID4 | None = None, + primary_email: str | None = None, + basket_token: UUID4 | None = None, + sfdc_id: str | None = None, + mofo_contact_id: str | None = None, + mofo_email_id: str | None = None, + amo_user_id: str | None = None, + fxa_id: str | None = None, + fxa_primary_email: str | None = None, +) -> list[ContactSchema]: """ Get all the data for multiple contacts by ID as a list of Contacts. @@ -222,36 +203,26 @@ def get_contacts_by_any_id( if email_id is not None: statement = statement.filter(Email.email_id == email_id) if primary_email is not None: - statement = statement.filter_by( - primary_email_insensitive_comparator=primary_email - ) + statement = statement.filter_by(primary_email_insensitive_comparator=primary_email) if basket_token is not None: statement = statement.filter(Email.basket_token == str(basket_token)) if sfdc_id is not None: statement = statement.filter(Email.sfdc_id == sfdc_id) if mofo_contact_id is not None: - statement = statement.join(Email.mofo).filter( - MozillaFoundationContact.mofo_contact_id == mofo_contact_id - ) + statement = statement.join(Email.mofo).filter(MozillaFoundationContact.mofo_contact_id == mofo_contact_id) if mofo_email_id is not None: - statement = statement.join(Email.mofo).filter( - MozillaFoundationContact.mofo_email_id == mofo_email_id - ) + statement = statement.join(Email.mofo).filter(MozillaFoundationContact.mofo_email_id == mofo_email_id) if amo_user_id is not None: statement = statement.join(Email.amo).filter(AmoAccount.user_id == amo_user_id) if fxa_id is not None: statement = statement.join(Email.fxa).filter(FirefoxAccount.fxa_id == fxa_id) if fxa_primary_email is not None: - statement = statement.join(Email.fxa).filter_by( - fxa_primary_email_insensitive_comparator=fxa_primary_email - ) - emails = cast(List[Email], statement.all()) + statement = statement.join(Email.fxa).filter_by(fxa_primary_email_insensitive_comparator=fxa_primary_email) + emails = cast(list[Email], statement.all()) return [ContactSchema.from_email(email) for email in emails] -def create_amo( - db: Session, email_id: UUID4, amo: AddOnsInSchema -) -> Optional[AmoAccount]: +def create_amo(db: Session, email_id: UUID4, amo: AddOnsInSchema) -> AmoAccount | None: if amo.is_default(): return None db_amo = AmoAccount(email_id=email_id, **amo.model_dump()) @@ -259,7 +230,7 @@ def create_amo( return db_amo -def create_or_update_amo(db: Session, email_id: UUID4, amo: Optional[AddOnsInSchema]): +def create_or_update_amo(db: Session, email_id: UUID4, amo: AddOnsInSchema | None): if not amo or amo.is_default(): db.query(AmoAccount).filter(AmoAccount.email_id == email_id).delete() return @@ -267,9 +238,7 @@ def create_or_update_amo(db: Session, email_id: UUID4, amo: Optional[AddOnsInSch # Providing update timestamp updated_amo = UpdatedAddOnsInSchema(**amo.model_dump()) stmt = insert(AmoAccount).values(email_id=email_id, **updated_amo.model_dump()) - stmt = stmt.on_conflict_do_update( - index_elements=[AmoAccount.email_id], set_=updated_amo.model_dump() - ) + stmt = stmt.on_conflict_do_update(index_elements=[AmoAccount.email_id], set_=updated_amo.model_dump()) db.execute(stmt) @@ -283,15 +252,11 @@ def create_or_update_email(db: Session, email: EmailPutSchema): updated_email = UpdatedEmailPutSchema(**email.model_dump()) stmt = insert(Email).values(**updated_email.model_dump()) - stmt = stmt.on_conflict_do_update( - index_elements=[Email.email_id], set_=updated_email.model_dump() - ) + stmt = stmt.on_conflict_do_update(index_elements=[Email.email_id], set_=updated_email.model_dump()) db.execute(stmt) -def create_fxa( - db: Session, email_id: UUID4, fxa: FirefoxAccountsInSchema -) -> Optional[FirefoxAccount]: +def create_fxa(db: Session, email_id: UUID4, fxa: FirefoxAccountsInSchema) -> FirefoxAccount | None: if fxa.is_default(): return None db_fxa = FirefoxAccount(email_id=email_id, **fxa.model_dump()) @@ -299,9 +264,7 @@ def create_fxa( return db_fxa -def create_or_update_fxa( - db: Session, email_id: UUID4, fxa: Optional[FirefoxAccountsInSchema] -): +def create_or_update_fxa(db: Session, email_id: UUID4, fxa: FirefoxAccountsInSchema | None): if not fxa or fxa.is_default(): (db.query(FirefoxAccount).filter(FirefoxAccount.email_id == email_id).delete()) return @@ -309,15 +272,11 @@ def create_or_update_fxa( updated_fxa = UpdatedFirefoxAccountsInSchema(**fxa.model_dump()) stmt = insert(FirefoxAccount).values(email_id=email_id, **updated_fxa.model_dump()) - stmt = stmt.on_conflict_do_update( - index_elements=[FirefoxAccount.email_id], set_=updated_fxa.model_dump() - ) + stmt = stmt.on_conflict_do_update(index_elements=[FirefoxAccount.email_id], set_=updated_fxa.model_dump()) db.execute(stmt) -def create_mofo( - db: Session, email_id: UUID4, mofo: MozillaFoundationInSchema -) -> Optional[MozillaFoundationContact]: +def create_mofo(db: Session, email_id: UUID4, mofo: MozillaFoundationInSchema) -> MozillaFoundationContact | None: if mofo.is_default(): return None db_mofo = MozillaFoundationContact(email_id=email_id, **mofo.model_dump()) @@ -325,28 +284,16 @@ def create_mofo( return db_mofo -def create_or_update_mofo( - db: Session, email_id: UUID4, mofo: Optional[MozillaFoundationInSchema] -): +def create_or_update_mofo(db: Session, email_id: UUID4, mofo: MozillaFoundationInSchema | None): if not mofo or mofo.is_default(): - ( - db.query(MozillaFoundationContact) - .filter(MozillaFoundationContact.email_id == email_id) - .delete() - ) + (db.query(MozillaFoundationContact).filter(MozillaFoundationContact.email_id == email_id).delete()) return - stmt = insert(MozillaFoundationContact).values( - email_id=email_id, **mofo.model_dump() - ) - stmt = stmt.on_conflict_do_update( - index_elements=[MozillaFoundationContact.email_id], set_=mofo.model_dump() - ) + stmt = insert(MozillaFoundationContact).values(email_id=email_id, **mofo.model_dump()) + stmt = stmt.on_conflict_do_update(index_elements=[MozillaFoundationContact.email_id], set_=mofo.model_dump()) db.execute(stmt) -def create_newsletter( - db: Session, email_id: UUID4, newsletter: NewsletterInSchema -) -> Optional[Newsletter]: +def create_newsletter(db: Session, email_id: UUID4, newsletter: NewsletterInSchema) -> Newsletter | None: if newsletter.is_default(): return None db_newsletter = Newsletter(email_id=email_id, **newsletter.model_dump()) @@ -354,18 +301,12 @@ def create_newsletter( return db_newsletter -def create_or_update_newsletters( - db: Session, email_id: UUID4, newsletters: List[NewsletterInSchema] -): +def create_or_update_newsletters(db: Session, email_id: UUID4, newsletters: list[NewsletterInSchema]): # Start by deleting the existing newsletters that are not specified as input. # We delete instead of set subscribed=False, because we want an idempotent # round-trip of PUT/GET at the API level. - names = [ - newsletter.name for newsletter in newsletters if not newsletter.is_default() - ] - db.query(Newsletter).filter( - Newsletter.email_id == email_id, Newsletter.name.notin_(names) - ).delete( + names = [newsletter.name for newsletter in newsletters if not newsletter.is_default()] + db.query(Newsletter).filter(Newsletter.email_id == email_id, Newsletter.name.notin_(names)).delete( # Do not bother synchronizing objects in the session. # We won't have stale objects because the next upsert query will update # the other remaining objects (equivalent to `Waitlist.name.in_(names)`). @@ -373,9 +314,7 @@ def create_or_update_newsletters( ) if newsletters: - stmt = insert(Newsletter).values( - [{"email_id": email_id, **n.model_dump()} for n in newsletters] - ) + stmt = insert(Newsletter).values([{"email_id": email_id, **n.model_dump()} for n in newsletters]) stmt = stmt.on_conflict_do_update( constraint="uix_email_name", set_={ @@ -387,9 +326,7 @@ def create_or_update_newsletters( db.execute(stmt) -def create_waitlist( - db: Session, email_id: UUID4, waitlist: WaitlistInSchema -) -> Optional[Waitlist]: +def create_waitlist(db: Session, email_id: UUID4, waitlist: WaitlistInSchema) -> Waitlist | None: if waitlist.is_default(): return None db_waitlist = Waitlist(email_id=email_id, **waitlist.model_dump()) @@ -397,29 +334,21 @@ def create_waitlist( return db_waitlist -def create_or_update_waitlists( - db: Session, email_id: UUID4, waitlists: List[WaitlistInSchema] -): +def create_or_update_waitlists(db: Session, email_id: UUID4, waitlists: list[WaitlistInSchema]): # Start by deleting the existing waitlists that are not specified as input. # We delete instead of set subscribed=False, because we want an idempotent # round-trip of PUT/GET at the API level. # Note: the contact is marked as pending synchronization at the API routers level. names = [waitlist.name for waitlist in waitlists if not waitlist.is_default()] - db.query(Waitlist).filter( - Waitlist.email_id == email_id, Waitlist.name.notin_(names) - ).delete( + db.query(Waitlist).filter(Waitlist.email_id == email_id, Waitlist.name.notin_(names)).delete( # Do not bother synchronizing objects in the session. # We won't have stale objects because the next upsert query will update # the other remaining objects (equivalent to `Waitlist.name.in_(names)`). synchronize_session=False ) - waitlists_to_upsert = [ - WaitlistInSchema(**waitlist.model_dump()) for waitlist in waitlists - ] + waitlists_to_upsert = [WaitlistInSchema(**waitlist.model_dump()) for waitlist in waitlists] if waitlists_to_upsert: - stmt = insert(Waitlist).values( - [{"email_id": email_id, **wl.model_dump()} for wl in waitlists] - ) + stmt = insert(Waitlist).values([{"email_id": email_id, **wl.model_dump()} for wl in waitlists]) stmt = stmt.on_conflict_do_update( constraint="uix_wl_email_name", set_={ @@ -435,7 +364,7 @@ def create_contact( db: Session, email_id: UUID4, contact: ContactInSchema, - metrics: Optional[Dict], + metrics: dict | None, ): create_email(db, contact.email) if contact.amo: @@ -452,9 +381,7 @@ def create_contact( create_waitlist(db, email_id, waitlist) -def create_or_update_contact( - db: Session, email_id: UUID4, contact: ContactPutSchema, metrics: Optional[Dict] -): +def create_or_update_contact(db: Session, email_id: UUID4, contact: ContactPutSchema, metrics: dict | None): create_or_update_email(db, contact.email) create_or_update_amo(db, email_id, contact.amo) create_or_update_fxa(db, email_id, contact.fxa) @@ -466,9 +393,7 @@ def create_or_update_contact( def delete_contact(db: Session, email_id: UUID4): db.query(AmoAccount).filter(AmoAccount.email_id == email_id).delete() - db.query(MozillaFoundationContact).filter( - MozillaFoundationContact.email_id == email_id - ).delete() + db.query(MozillaFoundationContact).filter(MozillaFoundationContact.email_id == email_id).delete() db.query(Newsletter).filter(Newsletter.email_id == email_id).delete() db.query(Waitlist).filter(Waitlist.email_id == email_id).delete() db.query(FirefoxAccount).filter(FirefoxAccount.email_id == email_id).delete() @@ -484,7 +409,7 @@ def _update_orm(orm: Base, update_dict: dict): def update_contact( # noqa: PLR0912 - db: Session, email: Email, update_data: dict, metrics: Optional[Dict] + db: Session, email: Email, update_data: dict, metrics: dict | None ) -> None: """Update an existing contact using a sparse update dictionary""" email_id = email.email_id @@ -492,9 +417,7 @@ def update_contact( # noqa: PLR0912 if "email" in update_data: _update_orm(email, update_data["email"]) - simple_groups: Dict[ - str, Tuple[Callable[[Session, UUID4, Any], Optional[Base]], Type[Any]] - ] = { + simple_groups: dict[str, tuple[Callable[[Session, UUID4, Any], Base | None], type[Any]]] = { "amo": (create_amo, AddOnsInSchema), "fxa": (create_fxa, FirefoxAccountsInSchema), "mofo": (create_mofo, MozillaFoundationInSchema), @@ -527,9 +450,7 @@ def update_contact( # noqa: PLR0912 if nl_update["name"] in existing: _update_orm(existing[nl_update["name"]], nl_update) elif nl_update.get("subscribed", True): - new = create_newsletter( - db, email_id, NewsletterInSchema(**nl_update) - ) + new = create_newsletter(db, email_id, NewsletterInSchema(**nl_update)) email.newsletters.append(new) existing = {} @@ -551,7 +472,7 @@ def update_contact( # noqa: PLR0912 email.waitlists.append(new) # On any PATCH event, the central/email table's time is updated as well. - _update_orm(email, {"update_timestamp": datetime.now(timezone.utc)}) + _update_orm(email, {"update_timestamp": datetime.now(UTC)}) def create_api_client(db: Session, api_client: ApiClientSchema, secret): @@ -564,14 +485,8 @@ def get_api_client_by_id(db: Session, client_id: str): return db.query(ApiClient).filter(ApiClient.client_id == client_id).one_or_none() -def get_active_api_client_ids(db: Session) -> List[str]: - rows = ( - db.query(ApiClient) - .filter(ApiClient.enabled.is_(True)) - .options(load_only(ApiClient.client_id)) - .order_by(ApiClient.client_id) - .all() - ) +def get_active_api_client_ids(db: Session) -> list[str]: + rows = db.query(ApiClient).filter(ApiClient.enabled.is_(True)).options(load_only(ApiClient.client_id)).order_by(ApiClient.client_id).all() return [row.client_id for row in rows] @@ -591,10 +506,5 @@ def get_contacts_from_newsletter(dbsession, newsletter_name): def get_contacts_from_waitlist(dbsession, waitlist_name): - entries = ( - dbsession.query(Waitlist) - .options(joinedload(Waitlist.email)) - .filter(Waitlist.name == waitlist_name) - .all() - ) + entries = dbsession.query(Waitlist).options(joinedload(Waitlist.email)).filter(Waitlist.name == waitlist_name).all() return entries diff --git a/ctms/dependencies.py b/ctms/dependencies.py index 8dbc54df..b084cd40 100644 --- a/ctms/dependencies.py +++ b/ctms/dependencies.py @@ -1,6 +1,5 @@ from datetime import timedelta from functools import lru_cache -from typing import Dict, Union from fastapi import Depends, HTTPException, Request from sqlalchemy.orm import Session @@ -13,7 +12,7 @@ from ctms.schemas import ApiClientSchema -@lru_cache() +@lru_cache def get_settings() -> Settings: return Settings() @@ -28,7 +27,7 @@ def get_db(): # pragma: no cover def get_token_settings( settings: Settings = Depends(get_settings), -) -> Dict[str, Union[str, timedelta]]: +) -> dict[str, str | timedelta]: return { "expires_delta": settings.token_expiration, "secret_key": settings.secret_key, @@ -75,9 +74,7 @@ def get_api_client( return api_client -def get_enabled_api_client( - request: Request, api_client: ApiClientSchema = Depends(get_api_client) -): +def get_enabled_api_client(request: Request, api_client: ApiClientSchema = Depends(get_api_client)): auth_info = auth_info_context.get() auth_info.clear() if not auth_info.get("client_id"): @@ -90,7 +87,7 @@ def get_enabled_api_client( return api_client -async def get_json(request: Request) -> Dict: +async def get_json(request: Request) -> dict: """ Get the request body as JSON. @@ -98,5 +95,5 @@ async def get_json(request: Request) -> Dict: before this dependency is resolved. If the body is form-encoded, it will raise an unknown exception. """ - the_json: Dict = await request.json() + the_json: dict = await request.json() return the_json diff --git a/ctms/log.py b/ctms/log.py index 6259875b..51dbb947 100644 --- a/ctms/log.py +++ b/ctms/log.py @@ -2,10 +2,6 @@ import logging import sys -from typing import Any, Dict, Optional - -from fastapi import Request -from starlette.routing import Match from ctms.auth import auth_info_context from ctms.config import Settings @@ -18,7 +14,7 @@ class AuthInfoLogFilter(logging.Filter): def filter(self, record: "logging.LogRecord") -> bool: # All records attributes will be logged as fields. - auth_info = auth_info_context.get() + auth_info = auth_info_context.get({}) for k, v in auth_info.items(): setattr(record, k, v) # MozLog also recommends using `uid` for user ids. @@ -66,9 +62,7 @@ def filter(self, record: "logging.LogRecord") -> bool: "uvicorn": {"level": logging.INFO}, "uvicorn.access": {"handlers": ["null"], "propagate": False}, "sqlalchemy.engine": { - "level": settings.logging_level.name - if settings.log_sqlalchemy - else logging.WARNING, + "level": settings.logging_level.name if settings.log_sqlalchemy else logging.WARNING, "propagate": False, }, }, diff --git a/ctms/metrics.py b/ctms/metrics.py index cb19df07..b2466480 100644 --- a/ctms/metrics.py +++ b/ctms/metrics.py @@ -1,21 +1,19 @@ """Prometheus metrics for instrumentation and monitoring.""" from itertools import product -from typing import Any, Optional, Type, cast +from typing import Any, cast from fastapi import FastAPI from fastapi.security import HTTPBasic from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram from prometheus_client.utils import INF from sqlalchemy.orm import Session -from starlette.routing import Match, Route +from starlette.routing import Route from ctms.auth import OAuth2ClientCredentials from ctms.crud import get_active_api_client_ids -METRICS_PARAMS: dict[ - str, tuple[Type[Counter] | Type[Histogram] | type[Gauge], dict] -] = { +METRICS_PARAMS: dict[str, tuple[type[Counter] | type[Histogram] | type[Gauge], dict]] = { "requests": ( Counter, { @@ -71,8 +69,13 @@ def set_metrics(metrics: Any) -> None: METRICS = metrics -get_metrics_registry = lambda: METRICS_REGISTRY -get_metrics = lambda: METRICS +def get_metrics_registry() -> CollectorRegistry: + return METRICS_REGISTRY + + +def get_metrics() -> Any: + return METRICS + oauth2_scheme = OAuth2ClientCredentials(tokenUrl="token") token_scheme = HTTPBasic(auto_error=False) @@ -87,9 +90,7 @@ def init_metrics(registry: CollectorRegistry) -> dict[str, Counter | Histogram | return metrics -def init_metrics_labels( - dbsession: Session, app: FastAPI, metrics: dict[str, Counter | Histogram] -) -> None: +def init_metrics_labels(dbsession: Session, app: FastAPI, metrics: dict[str, Counter | Histogram]) -> None: """Create the initial metric combinations.""" openapi = app.openapi() client_ids = get_active_api_client_ids(dbsession) or ["none"] @@ -108,9 +109,7 @@ def init_metrics_labels( status_codes = [] for method_lower, mspec in api_spec.items(): if method_lower.upper() in methods: - status_codes.extend( - [int(code) for code in list(mspec.get("responses", [200]))] - ) + status_codes.extend([int(code) for code in list(mspec.get("responses", [200]))]) is_api |= "security" in mspec elif path == "/": status_codes = [307] @@ -129,17 +128,15 @@ def init_metrics_labels( for api_combo in product(methods, status_code_families): method, status_code_family = api_combo for client_id in client_ids: - api_request_metric.labels( - method, path, client_id, status_code_family - ) + api_request_metric.labels(method, path, client_id, status_code_family) def emit_response_metrics( - path_template: Optional[str], + path_template: str | None, method: str, duration_s: float, status_code: int, - client_id: Optional[str], + client_id: str | None, metrics: dict[str, Counter | Histogram], ) -> None: """Emit metrics for a response.""" diff --git a/ctms/models.py b/ctms/models.py index 23adde3a..6d6d7cdb 100644 --- a/ctms/models.py +++ b/ctms/models.py @@ -54,9 +54,7 @@ class Email(Base, TimestampMixin): __tablename__ = "emails" __mapper_args__ = {"eager_defaults": True} - email_id = mapped_column( - UUID, primary_key=True, server_default="uuid_generate_v4()" - ) + email_id = mapped_column(UUID, primary_key=True, server_default="uuid_generate_v4()") primary_email = mapped_column(String(255), unique=True, nullable=False) basket_token = mapped_column(String(255), unique=True) sfdc_id = mapped_column(String(255), index=True) @@ -69,17 +67,11 @@ class Email(Base, TimestampMixin): has_opted_out_of_email = mapped_column(Boolean) unsubscribe_reason = mapped_column(Text) - newsletters = relationship( - "Newsletter", back_populates="email", order_by="Newsletter.name" - ) - waitlists = relationship( - "Waitlist", back_populates="email", order_by="Waitlist.name" - ) + newsletters = relationship("Newsletter", back_populates="email", order_by="Newsletter.name") + waitlists = relationship("Waitlist", back_populates="email", order_by="Waitlist.name") fxa = relationship("FirefoxAccount", back_populates="email", uselist=False) amo = relationship("AmoAccount", back_populates="email", uselist=False) - mofo = relationship( - "MozillaFoundationContact", back_populates="email", uselist=False - ) + mofo = relationship("MozillaFoundationContact", back_populates="email", uselist=False) # Class Comparators @hybrid_property @@ -105,9 +97,7 @@ class Newsletter(Base, TimestampMixin): __tablename__ = "newsletters" id = mapped_column(Integer, primary_key=True) - email_id: Mapped[UUID4] = mapped_column( - UUID(as_uuid=True), ForeignKey(Email.email_id), nullable=False - ) + email_id: Mapped[UUID4] = mapped_column(UUID(as_uuid=True), ForeignKey(Email.email_id), nullable=False) name = mapped_column(String(255), nullable=False) subscribed = mapped_column(Boolean) format = mapped_column(String(1)) @@ -124,9 +114,7 @@ class Waitlist(Base, TimestampMixin): __tablename__ = "waitlists" id = mapped_column(Integer, primary_key=True) - email_id: Mapped[UUID4] = mapped_column( - UUID(as_uuid=True), ForeignKey(Email.email_id), nullable=False - ) + email_id: Mapped[UUID4] = mapped_column(UUID(as_uuid=True), ForeignKey(Email.email_id), nullable=False) name = mapped_column(String(255), nullable=False) source = mapped_column(Text) subscribed = mapped_column(Boolean, nullable=False, default=True) @@ -143,9 +131,7 @@ class FirefoxAccount(Base, TimestampMixin): id = mapped_column(Integer, primary_key=True) fxa_id = mapped_column(String(255), unique=True) - email_id = mapped_column( - UUID(as_uuid=True), ForeignKey(Email.email_id), unique=True, nullable=False - ) + email_id = mapped_column(UUID(as_uuid=True), ForeignKey(Email.email_id), unique=True, nullable=False) primary_email = mapped_column(String(255), index=True) created_date = mapped_column(String(50)) lang = mapped_column(String(255)) @@ -173,9 +159,7 @@ class AmoAccount(Base, TimestampMixin): __tablename__ = "amo" id = mapped_column(Integer, primary_key=True) - email_id = mapped_column( - UUID(as_uuid=True), ForeignKey(Email.email_id), unique=True, nullable=False - ) + email_id = mapped_column(UUID(as_uuid=True), ForeignKey(Email.email_id), unique=True, nullable=False) add_on_ids = mapped_column(String(500)) display_name = mapped_column(String(255)) email_opt_in = mapped_column(Boolean) @@ -206,9 +190,7 @@ class MozillaFoundationContact(Base, TimestampMixin): __tablename__ = "mofo" id = mapped_column(Integer, primary_key=True) - email_id = mapped_column( - UUID(as_uuid=True), ForeignKey(Email.email_id), unique=True, nullable=False - ) + email_id = mapped_column(UUID(as_uuid=True), ForeignKey(Email.email_id), unique=True, nullable=False) mofo_email_id = mapped_column(String(255), unique=True) mofo_contact_id = mapped_column(String(255), index=True) mofo_relevant = mapped_column(Boolean) diff --git a/ctms/routers/contacts.py b/ctms/routers/contacts.py index e8d0d988..85afc54e 100644 --- a/ctms/routers/contacts.py +++ b/ctms/routers/contacts.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Dict, List, Literal, Optional, Union +from typing import Annotated, Literal from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response @@ -56,15 +56,15 @@ def get_contact_or_404(db: Session, email_id) -> ContactSchema: def all_ids( - email_id: Optional[UUID] = None, - primary_email: Optional[str] = None, - basket_token: Optional[UUID] = None, - sfdc_id: Optional[str] = None, - mofo_contact_id: Optional[str] = None, - mofo_email_id: Optional[str] = None, - amo_user_id: Optional[str] = None, - fxa_id: Optional[str] = None, - fxa_primary_email: Optional[str] = None, + email_id: UUID | None = None, + primary_email: str | None = None, + basket_token: UUID | None = None, + sfdc_id: str | None = None, + mofo_contact_id: str | None = None, + mofo_email_id: str | None = None, + amo_user_id: str | None = None, + fxa_id: str | None = None, + fxa_primary_email: str | None = None, ): """Alternate IDs, injected as a dependency.""" return { @@ -122,11 +122,7 @@ def get_bulk_contacts_by_timestamp_or_4xx( ) next_url = ( - f"{get_settings().server_prefix}/updates?" - f"start={start_time.isoformat()}" - f"&end={end_time.isoformat()}" - f"&limit={limit}" - f"&after={after_encoded} " + f"{get_settings().server_prefix}/updates?start={start_time.isoformat()}&end={end_time.isoformat()}&limit={limit}&after={after_encoded} " ) return CTMSBulkResponse( @@ -142,7 +138,7 @@ def get_bulk_contacts_by_timestamp_or_4xx( @router.get( "/ctms", summary="Get all contacts matching alternate IDs", - response_model=List[CTMSResponse], + response_model=list[CTMSResponse], responses={ 400: {"model": BadRequestResponse}, 401: {"model": UnauthorizedResponse}, @@ -151,14 +147,12 @@ def get_bulk_contacts_by_timestamp_or_4xx( ) def read_ctms_by_any_id( request: Request, - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], ids=Depends(all_ids), ): if not any(ids.values()): - detail = ( - f"No identifiers provided, at least one is needed: {', '.join(ids.keys())}" - ) + detail = f"No identifiers provided, at least one is needed: {', '.join(ids.keys())}" raise HTTPException(status_code=400, detail=detail) contacts = get_contacts_by_any_id(db, **ids) return [CTMSResponse(**contact.model_dump()) for contact in contacts] @@ -176,9 +170,9 @@ def read_ctms_by_any_id( ) def read_ctms_by_email_id( request: Request, - email_id: UUID = Path(..., title="The Email ID"), - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), + email_id: Annotated[UUID, Path(..., title="The Email ID")], + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], ): resp = get_ctms_response_or_404(db, email_id) return resp @@ -200,9 +194,9 @@ def create_ctms_contact( contact: ContactInSchema, request: Request, response: Response, - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), - content_json: Optional[Dict] = Depends(get_json), + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], + content_json: Annotated[dict | None, Depends(get_json)], ): contact.email.email_id = contact.email.email_id or uuid4() email_id = contact.email.email_id @@ -241,10 +235,10 @@ def create_or_update_ctms_contact( contact: ContactPutSchema, request: Request, response: Response, - email_id: UUID = Path(..., title="The Email ID"), - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), - content_json: Optional[Dict] = Depends(get_json), + email_id: Annotated[UUID, Path(..., title="The Email ID")], + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], + content_json: Annotated[dict | None, Depends(get_json)], ): if contact.email.email_id: if contact.email.email_id != email_id: @@ -285,16 +279,12 @@ def partial_update_ctms_contact( contact: ContactPatchSchema, request: Request, response: Response, - email_id: UUID = Path(..., title="The Email ID"), - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), - content_json: Optional[Dict] = Depends(get_json), + email_id: Annotated[UUID, Path(..., title="The Email ID")], + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], + content_json: Annotated[dict | None, Depends(get_json)], ): - if ( - contact.email - and getattr(contact.email, "email_id") - and contact.email.email_id != email_id - ): + if contact.email and contact.email.email_id and contact.email.email_id != email_id: raise HTTPException( status_code=422, detail="cannot change email_id", @@ -310,10 +300,7 @@ def partial_update_ctms_contact( if isinstance(e, IntegrityError): raise HTTPException( status_code=409, - detail=( - "Contact with primary_email, basket_token, mofo_email_id," - " or fxa_id already exists" - ), + detail="Contact with primary_email, basket_token, mofo_email_id, or fxa_id already exists", ) from e raise response.status_code = 200 @@ -323,7 +310,7 @@ def partial_update_ctms_contact( @router.delete( "/ctms/{primary_email}", summary="Delete all contact information from primary email", - response_model=List[IdentityResponse], + response_model=list[IdentityResponse], responses={ 404: {"model": NotFoundResponse}, }, @@ -331,8 +318,8 @@ def partial_update_ctms_contact( ) def delete_contact_by_primary_email( primary_email: str, - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], ): ids = all_ids(primary_email=primary_email.lower()) contacts = get_contacts_by_any_id(db, **ids) @@ -358,12 +345,12 @@ def delete_contact_by_primary_email( ) def read_ctms_in_bulk_by_timestamps_and_limit( start: datetime, - end: Optional[Union[datetime, Literal[""]]] = None, - limit: Optional[Union[int, Literal[""]]] = None, - after: Optional[str] = None, - mofo_relevant: Optional[Union[bool, Literal[""]]] = None, - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), + end: datetime | Literal[""] | None = None, + limit: int | Literal[""] | None = None, + after: str | None = None, + mofo_relevant: bool | Literal[""] | None = None, + db: Session = Depends(get_db), # noqa: FAST002, parameter without default + api_client: ApiClientSchema = Depends(get_enabled_api_client), # noqa: FAST002, parameter without default ): try: bulk_request = BulkRequestSchema( @@ -376,15 +363,13 @@ def read_ctms_in_bulk_by_timestamps_and_limit( return get_bulk_contacts_by_timestamp_or_4xx(db=db, **bulk_request.model_dump()) except ValidationError as e: detail = {"errors": json.loads(e.json())} - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail - ) from e + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail) from e @router.get( "/identities", summary="Get identities associated with alternate IDs", - response_model=List[IdentityResponse], + response_model=list[IdentityResponse], responses={ 400: {"model": BadRequestResponse}, 401: {"model": UnauthorizedResponse}, @@ -392,14 +377,12 @@ def read_ctms_in_bulk_by_timestamps_and_limit( tags=["Private"], ) def read_identities( - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], ids=Depends(all_ids), ): if not any(ids.values()): - detail = ( - f"No identifiers provided, at least one is needed: {', '.join(ids.keys())}" - ) + detail = f"No identifiers provided, at least one is needed: {', '.join(ids.keys())}" raise HTTPException(status_code=400, detail=detail) contacts = get_contacts_by_any_id(db, **ids) return [contact.as_identity_response() for contact in contacts] @@ -416,9 +399,9 @@ def read_identities( tags=["Private"], ) def read_identity( - email_id: UUID = Path(..., title="The email ID"), - db: Session = Depends(get_db), - api_client: ApiClientSchema = Depends(get_enabled_api_client), + email_id: Annotated[UUID, Path(..., title="The Email ID")], + db: Annotated[Session, Depends(get_db)], + api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)], ): contact = get_contact_or_404(db, email_id) return contact.as_identity_response() diff --git a/ctms/routers/platform.py b/ctms/routers/platform.py index 0975c32d..c562dd34 100644 --- a/ctms/routers/platform.py +++ b/ctms/routers/platform.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Annotated from dockerflow import checks as dockerflow_checks from fastapi import APIRouter, Depends, HTTPException, Request, Response @@ -48,15 +48,13 @@ def root(): ) def login( request: Request, - db: Session = Depends(get_db), - form_data: OAuth2ClientCredentialsRequestForm = Depends(), - basic_credentials: Optional[HTTPBasicCredentials] = Depends(token_scheme), + db: Annotated[Session, Depends(get_db)], + form_data: Annotated[OAuth2ClientCredentialsRequestForm, Depends()], + basic_credentials: Annotated[HTTPBasicCredentials | None, Depends(token_scheme)], token_settings=Depends(get_token_settings), ): auth_info = auth_info_context.get() - failed_auth = HTTPException( - status_code=400, detail="Incorrect username or password" - ) + failed_auth = HTTPException(status_code=400, detail="Incorrect username or password") if form_data.client_id and form_data.client_secret: client_id = form_data.client_id @@ -82,9 +80,7 @@ def login( auth_info["token_fail"] = "Bad credentials" raise failed_auth - access_token = create_access_token( - data={"sub": f"api_client:{client_id}"}, **token_settings - ) + access_token = create_access_token(data={"sub": f"api_client:{client_id}"}, **token_settings) return { "access_token": access_token, "token_type": "bearer", @@ -99,9 +95,7 @@ def database(): with SessionLocal() as db: alive = ping(db) if not alive: - result.append( - dockerflow_checks.Error("Database not reachable", id="db.0001") - ) + result.append(dockerflow_checks.Error("Database not reachable", id="db.0001")) return result # Report number of contacts in the database. # Sending the metric in this heartbeat endpoint is simpler than reporting @@ -120,7 +114,7 @@ def database(): @router.get("/__crash__", tags=["Platform"], include_in_schema=False) -def crash(api_client: ApiClientSchema = Depends(get_enabled_api_client)): +def crash(api_client: Annotated[ApiClientSchema, Depends(get_enabled_api_client)]): """Raise an exception to test Sentry integration.""" raise RuntimeError("Test exception handling") diff --git a/ctms/schemas/__init__.py b/ctms/schemas/__init__.py index fb5a31d0..edd4e43d 100644 --- a/ctms/schemas/__init__.py +++ b/ctms/schemas/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 -- Allow unused imports from .addons import AddOnsInSchema, AddOnsSchema, UpdatedAddOnsInSchema from .api_client import ApiClientSchema from .bulk import BulkRequestSchema diff --git a/ctms/schemas/addons.py b/ctms/schemas/addons.py index e6c38ff3..587d0cc8 100644 --- a/ctms/schemas/addons.py +++ b/ctms/schemas/addons.py @@ -1,5 +1,4 @@ -from datetime import date, datetime, timezone -from typing import Optional +from datetime import UTC, date, datetime from pydantic import ConfigDict, Field @@ -19,12 +18,12 @@ class AddOnsBase(ComparableBase): contact data on the return from Salesforce. """ - add_on_ids: Optional[str] = Field( + add_on_ids: str | None = Field( default=None, description="Comma-separated list of add-ons for account, AMO_Add_On_ID_s__c in Salesforce", examples=["add-on-1,add-on-2"], ) - display_name: Optional[str] = Field( + display_name: str | None = Field( default=None, max_length=255, description="Display name on AMO, AMO_Display_Name__c in Salesforce", @@ -34,24 +33,24 @@ class AddOnsBase(ComparableBase): default=False, description="Account has opted into emails, AMO_Email_Opt_In__c in Salesforce", ) - language: Optional[str] = Field( + language: str | None = Field( default=None, max_length=5, description="Account language, AMO_Language__c in Salesforce", examples=["en"], ) - last_login: Optional[date] = Field( + last_login: date | None = Field( default=None, description="Last login date on addons.mozilla.org, AMO_Last_Login__c in Salesforce", examples=["2021-01-28"], ) - location: Optional[str] = Field( + location: str | None = Field( default=None, max_length=255, description="Free-text location on AMO, AMO_Location__c in Salesforce", examples=["California"], ) - profile_url: Optional[str] = Field( + profile_url: str | None = Field( default=None, max_length=40, description="AMO profile URL, AMO_Profile_URL__c in Salesforce", @@ -62,13 +61,13 @@ class AddOnsBase(ComparableBase): description="True if user is from an Add-on sync, AMO_User__c in Salesforce", examples=[True], ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, max_length=40, description="User ID on AMO, AMO_User_ID__c in Salesforce", examples=["98765"], ) - username: Optional[str] = Field( + username: str | None = Field( default=None, max_length=100, description="Username on AMO, AMO_Username__c in Salesforce", @@ -83,19 +82,19 @@ class AddOnsBase(ComparableBase): class UpdatedAddOnsInSchema(AddOnsInSchema): update_timestamp: ZeroOffsetDatetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="AMO data update timestamp", examples=["2021-01-28T21:26:57.511+00:00"], ) class AddOnsSchema(AddOnsBase): - create_timestamp: Optional[ZeroOffsetDatetime] = Field( + create_timestamp: ZeroOffsetDatetime | None = Field( default=None, description="AMO data creation timestamp", examples=["2020-12-05T19:21:50.908000+00:00"], ) - update_timestamp: Optional[ZeroOffsetDatetime] = Field( + update_timestamp: ZeroOffsetDatetime | None = Field( default=None, description="AMO data update timestamp", examples=["2021-02-04T15:36:57.511000+00:00"], diff --git a/ctms/schemas/bulk.py b/ctms/schemas/bulk.py index 8010414a..030156f1 100644 --- a/ctms/schemas/bulk.py +++ b/ctms/schemas/bulk.py @@ -1,6 +1,6 @@ import base64 -from datetime import datetime, timezone -from typing import Literal, Optional, Tuple, Union +from datetime import UTC, datetime +from typing import Literal import dateutil.parser from pydantic import Field, field_validator @@ -15,20 +15,16 @@ class BulkRequestSchema(ComparableBase): start_time: datetime - end_time: Optional[Union[datetime, Literal[""]]] = Field( - default=None, validate_default=True - ) + end_time: datetime | Literal[""] | None = Field(default=None, validate_default=True) @field_validator("end_time", mode="before") @classmethod def end_time_must_not_be_blank(cls, value): if value in BLANK_VALS: - return datetime.now(timezone.utc) + return datetime.now(UTC) return value - limit: Optional[Union[int, Literal[""]]] = Field( - default=None, validate_default=True - ) + limit: int | Literal[""] | None = Field(default=None, validate_default=True) @field_validator("limit", mode="before") @classmethod @@ -41,9 +37,7 @@ def limit_must_adhere_to_validations(cls, value): raise ValueError('"limit" should be less than or equal to 1000') return value - mofo_relevant: Optional[Union[bool, Literal[""]]] = Field( - default=None, validate_default=True - ) + mofo_relevant: bool | Literal[""] | None = Field(default=None, validate_default=True) @field_validator("mofo_relevant", mode="before") @classmethod @@ -52,7 +46,7 @@ def mofo_relevant_must_not_be_blank(cls, value): return None # Default return value - after: Optional[str] = Field(default=None, validate_default=True) + after: str | None = Field(default=None, validate_default=True) @field_validator("after", mode="before") def after_must_be_base64_decodable(cls, value): @@ -60,16 +54,12 @@ def after_must_be_base64_decodable(cls, value): return None # Default try: str_decode = base64.urlsafe_b64decode(value) - return str( - str_decode.decode("utf-8") - ) # 'after' should be decodable otherwise err and invalid + return str(str_decode.decode("utf-8")) # 'after' should be decodable otherwise err and invalid except Exception as e: - raise ValueError( - "'after' param validation error when decoding value." - ) from e + raise ValueError("'after' param validation error when decoding value.") from e @staticmethod - def extractor_for_bulk_encoded_details(after: str) -> Tuple[str, datetime]: + def extractor_for_bulk_encoded_details(after: str) -> tuple[str, datetime]: result_after_list = after.split(",") after_email_id = result_after_list[0] after_start_time = dateutil.parser.parse(result_after_list[1]) @@ -77,7 +67,5 @@ def extractor_for_bulk_encoded_details(after: str) -> Tuple[str, datetime]: @staticmethod def compressor_for_bulk_encoded_details(last_email_id, last_update_time): - result_after_encoded = base64.urlsafe_b64encode( - f"{last_email_id},{last_update_time}".encode("utf-8") - ) + result_after_encoded = base64.urlsafe_b64encode(f"{last_email_id},{last_update_time}".encode()) return result_after_encoded.decode() diff --git a/ctms/schemas/common.py b/ctms/schemas/common.py index 15bd2ca7..3ed9ad1d 100644 --- a/ctms/schemas/common.py +++ b/ctms/schemas/common.py @@ -5,9 +5,7 @@ http_url_adapter = TypeAdapter(AnyUrl) -AnyUrlString = Annotated[ - str, BeforeValidator(lambda value: str(http_url_adapter.validate_python(value))) -] +AnyUrlString = Annotated[str, BeforeValidator(lambda value: str(http_url_adapter.validate_python(value)))] ZeroOffsetDatetime = Annotated[datetime, PlainSerializer(lambda dt: dt.isoformat())] diff --git a/ctms/schemas/contact.py b/ctms/schemas/contact.py index bfe4c493..3c7fa7c2 100644 --- a/ctms/schemas/contact.py +++ b/ctms/schemas/contact.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TYPE_CHECKING, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal from uuid import UUID from pydantic import BaseModel, Field, field_validator, model_validator @@ -39,11 +39,11 @@ class ContactSchema(ComparableBase): """A complete contact.""" - amo: Optional[AddOnsSchema] = None + amo: AddOnsSchema | None = None email: EmailSchema - fxa: Optional[FirefoxAccountsSchema] = None - mofo: Optional[MozillaFoundationSchema] = None - newsletters: List[NewsletterTableSchema] = Field( + fxa: FirefoxAccountsSchema | None = None + mofo: MozillaFoundationSchema | None = None + newsletters: list[NewsletterTableSchema] = Field( default_factory=list, description="List of newsletters for which the contact is or was subscribed", examples=[ @@ -63,7 +63,7 @@ class ContactSchema(ComparableBase): ] ], ) - waitlists: List[WaitlistTableSchema] = Field( + waitlists: list[WaitlistTableSchema] = Field( default_factory=list, description="List of waitlists for which the contact is or was subscribed", examples=[ @@ -122,11 +122,11 @@ def as_identity_response(self) -> "IdentityResponse": class ContactInBase(ComparableBase): """A contact as provided by callers.""" - amo: Optional[AddOnsInSchema] = None + amo: AddOnsInSchema | None = None email: EmailBase - fxa: Optional[FirefoxAccountsInSchema] = None - mofo: Optional[MozillaFoundationInSchema] = None - newsletters: List[NewsletterInSchema] = Field( + fxa: FirefoxAccountsInSchema | None = None + mofo: MozillaFoundationInSchema | None = None + newsletters: list[NewsletterInSchema] = Field( default_factory=list, examples=[ [ @@ -139,7 +139,7 @@ class ContactInBase(ComparableBase): ] ], ) - waitlists: List[WaitlistInSchema] = Field( + waitlists: list[WaitlistInSchema] = Field( default_factory=list, examples=[ [ @@ -191,29 +191,18 @@ class ContactPatchSchema(ComparableBase): "UNSUBSCRIBE" instead of lists or objects. """ - amo: Optional[Union[Literal["DELETE"], AddOnsInSchema]] = Field( - None, description='Add-ons data to update, or "DELETE" to reset.' - ) - email: Optional[EmailPatchSchema] = None - fxa: Optional[Union[Literal["DELETE"], FirefoxAccountsInSchema]] = Field( - None, description='Firefox Accounts data to update, or "DELETE" to reset.' - ) - mofo: Optional[Union[Literal["DELETE"], MozillaFoundationInSchema]] = Field( - None, description='Mozilla Foundation data to update, or "DELETE" to reset.' - ) - newsletters: Optional[Union[List[NewsletterSchema], Literal["UNSUBSCRIBE"]]] = ( - Field( - None, - description=( - "List of newsletters to add or update, or 'UNSUBSCRIBE' to" - " unsubscribe from all." - ), - examples=[[{"name": "firefox-welcome", "subscribed": False}]], - ) + amo: Literal["DELETE"] | AddOnsInSchema | None = Field(None, description='Add-ons data to update, or "DELETE" to reset.') + email: EmailPatchSchema | None = None + fxa: Literal["DELETE"] | FirefoxAccountsInSchema | None = Field(None, description='Firefox Accounts data to update, or "DELETE" to reset.') + mofo: Literal["DELETE"] | MozillaFoundationInSchema | None = Field(None, description='Mozilla Foundation data to update, or "DELETE" to reset.') + newsletters: list[NewsletterSchema] | Literal["UNSUBSCRIBE"] | None = Field( + None, + description="List of newsletters to add or update, or 'UNSUBSCRIBE' to unsubscribe from all.", + examples=[[{"name": "firefox-welcome", "subscribed": False}]], ) - waitlists: Optional[Union[List[WaitlistInSchema], Literal["UNSUBSCRIBE"]]] = Field( + waitlists: list[WaitlistInSchema] | Literal["UNSUBSCRIBE"] | None = Field( None, - description=("List of waitlists to add or update."), + description="List of waitlists to add or update.", examples=[ [ { @@ -236,8 +225,8 @@ class CTMSResponse(BaseModel): email: EmailSchema fxa: FirefoxAccountsSchema mofo: MozillaFoundationSchema - newsletters: List[NewsletterTimestampedSchema] - waitlists: List[WaitlistTimestampedSchema] + newsletters: list[NewsletterTimestampedSchema] + waitlists: list[WaitlistTimestampedSchema] # Retro-compat fields vpn_waitlist: VpnWaitlistSchema relay_waitlist: RelayWaitlistSchema @@ -279,13 +268,8 @@ def legacy_waitlists(cls, values): # first waitlist is set as the value of `relay_waitlist["geo"]`. This # property is intended for legacy consumers. New consumers should prefer the # `waitlists` property of the contact schema - if ( - waitlist.name.startswith("relay") - and values["relay_waitlist"].geo is None - ): - values["relay_waitlist"] = RelayWaitlistSchema( - geo=waitlist.fields.get("geo") - ) + if waitlist.name.startswith("relay") and values["relay_waitlist"].geo is None: + values["relay_waitlist"] = RelayWaitlistSchema(geo=waitlist.fields.get("geo")) return values @@ -297,9 +281,7 @@ class CTMSSingleResponse(CTMSResponse): Similar to ContactSchema, but groups are required and includes status: OK """ - status: Literal["ok"] = Field( - default="ok", description="Request was successful", examples=["ok"] - ) + status: Literal["ok"] = Field(default="ok", description="Request was successful", examples=["ok"]) class CTMSBulkResponse(BaseModel): @@ -311,9 +293,9 @@ class CTMSBulkResponse(BaseModel): start: datetime end: datetime limit: int - after: Optional[str] = None - next: Optional[Union[AnyUrlString, str]] = None - items: List[CTMSResponse] + after: str | None = None + next: AnyUrlString | str | None = None + items: list[CTMSResponse] class IdentityResponse(BaseModel): @@ -321,10 +303,10 @@ class IdentityResponse(BaseModel): email_id: UUID primary_email: str - basket_token: Optional[UUID] = None - sfdc_id: Optional[str] = None - mofo_contact_id: Optional[str] = None - mofo_email_id: Optional[str] = None - amo_user_id: Optional[str] = None - fxa_id: Optional[str] = None - fxa_primary_email: Optional[str] = None + basket_token: UUID | None = None + sfdc_id: str | None = None + mofo_contact_id: str | None = None + mofo_email_id: str | None = None + amo_user_id: str | None = None + fxa_id: str | None = None + fxa_primary_email: str | None = None diff --git a/ctms/schemas/email.py b/ctms/schemas/email.py index dbc5f561..97a15444 100644 --- a/ctms/schemas/email.py +++ b/ctms/schemas/email.py @@ -1,5 +1,5 @@ -from datetime import datetime, timezone -from typing import Literal, Optional +from datetime import UTC, datetime +from typing import Literal from uuid import UUID from pydantic import UUID4, ConfigDict, Field, field_validator @@ -18,7 +18,7 @@ class EmailBase(ComparableBase): description="Contact email address, Email in Salesforce", examples=["contact@example.com"], ) - basket_token: Optional[UUID] = Field( + basket_token: UUID | None = Field( default=None, description="Basket token, Token__c in Salesforce", examples=["c4a7d759-bb52-457b-896b-90f1d3ef8433"], @@ -28,25 +28,25 @@ class EmailBase(ComparableBase): description="User has clicked a confirmation link", examples=[True], ) - sfdc_id: Optional[str] = Field( + sfdc_id: str | None = Field( default=None, max_length=255, description="Salesforce legacy ID, Id in Salesforce", examples=["001A000023aABcDEFG"], ) - first_name: Optional[str] = Field( + first_name: str | None = Field( default=None, max_length=255, description="First name of contact, FirstName in Salesforce", examples=["Jane"], ) - last_name: Optional[str] = Field( + last_name: str | None = Field( default=None, max_length=255, description="Last name of contact, LastName in Salesforce", examples=["Doe"], ) - mailing_country: Optional[str] = Field( + mailing_country: str | None = Field( default=None, max_length=255, description="Mailing country code, 2 lowercase letters, MailingCountryCode in Salesforce", @@ -56,7 +56,7 @@ class EmailBase(ComparableBase): default="H", description="Email format, H=HTML, T=Plain Text, N and Empty=No selection, Email_Format__c in Salesforce", ) - email_lang: Optional[str] = Field( + email_lang: str | None = Field( default="en", max_length=5, description="Email language code, usually 2 lowercase letters, Email_Language__c in Salesforce", @@ -65,7 +65,7 @@ class EmailBase(ComparableBase): default=False, description="User has opted-out, HasOptedOutOfEmail in Salesforce", ) - unsubscribe_reason: Optional[str] = Field( + unsubscribe_reason: str | None = Field( default=None, description="Reason for unsubscribing, in basket IGNORE_USER_FIELDS, Unsubscribe_Reason__c in Salesforce", ) @@ -77,12 +77,12 @@ class EmailSchema(EmailBase): description=EMAIL_ID_DESCRIPTION, examples=[EMAIL_ID_EXAMPLE], ) - create_timestamp: Optional[ZeroOffsetDatetime] = Field( + create_timestamp: ZeroOffsetDatetime | None = Field( default=None, description="Contact creation date, CreatedDate in Salesforce", examples=["2020-03-28T15:41:00.000+00:00"], ) - update_timestamp: Optional[ZeroOffsetDatetime] = Field( + update_timestamp: ZeroOffsetDatetime | None = Field( default=None, description="Contact last modified date, LastModifiedDate in Salesforce", examples=["2021-01-28T21:26:57.511+00:00"], @@ -92,7 +92,7 @@ class EmailSchema(EmailBase): class EmailInSchema(EmailBase): """Nearly identical to EmailPutSchema but the email_id is not required.""" - email_id: Optional[UUID4] = Field( + email_id: UUID4 | None = Field( default=None, description=EMAIL_ID_DESCRIPTION, examples=[EMAIL_ID_EXAMPLE], @@ -111,7 +111,7 @@ class EmailPutSchema(EmailBase): class EmailPatchSchema(EmailInSchema): """Nearly identical to EmailInSchema but nothing is required.""" - primary_email: Optional[str] = None + primary_email: str | None = None @field_validator("primary_email") @classmethod @@ -123,7 +123,7 @@ def prevent_none(cls, value): class UpdatedEmailPutSchema(EmailPutSchema): update_timestamp: ZeroOffsetDatetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="Contact last modified date, LastModifiedDate in Salesforce", examples=["2021-01-28T21:26:57.511+00:00"], ) diff --git a/ctms/schemas/fxa.py b/ctms/schemas/fxa.py index d662e9c6..b53776ad 100644 --- a/ctms/schemas/fxa.py +++ b/ctms/schemas/fxa.py @@ -1,5 +1,4 @@ -from datetime import datetime, timezone -from typing import Optional +from datetime import UTC, datetime from pydantic import ConfigDict, Field @@ -10,29 +9,29 @@ class FirefoxAccountsBase(ComparableBase): """The Firefox Account schema.""" - fxa_id: Optional[str] = Field( + fxa_id: str | None = Field( default=None, description="Firefox Accounts foreign ID, FxA_Id__c in Salesforce", max_length=50, examples=["6eb6ed6ac3b64259968aa490c6c0b9df"], # pragma: allowlist secret ) - primary_email: Optional[str] = Field( + primary_email: str | None = Field( default=None, description="FxA Email, can be foreign ID, FxA_Primary_Email__c in Salesforce", examples=["my-fxa-acct@example.com"], ) - created_date: Optional[str] = Field( + created_date: str | None = Field( default=None, description="Source is unix timestamp, FxA_Created_Date__c in Salesforce", examples=["2021-01-29T18:43:49.082375+00:00"], ) - lang: Optional[str] = Field( + lang: str | None = Field( default=None, max_length=255, description="FxA Locale (from browser Accept-Language header), FxA_Language__c in Salesforce", examples=["en,en-US"], ) - first_service: Optional[str] = Field( + first_service: str | None = Field( default=None, max_length=50, description="First service that an FxA user used, FirstService__c in Salesforce", @@ -40,10 +39,7 @@ class FirefoxAccountsBase(ComparableBase): ) account_deleted: bool = Field( default=False, - description=( - "Set to True when FxA account deleted or dupe," - " FxA_Account_Deleted__c in Salesforce" - ), + description="Set to True when FxA account deleted or dupe, FxA_Account_Deleted__c in Salesforce", ) model_config = ConfigDict(from_attributes=True) @@ -55,7 +51,7 @@ class FirefoxAccountsBase(ComparableBase): class UpdatedFirefoxAccountsInSchema(FirefoxAccountsInSchema): update_timestamp: ZeroOffsetDatetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="FXA data update timestamp", examples=["2021-01-28T21:26:57.511+00:00"], ) diff --git a/ctms/schemas/mofo.py b/ctms/schemas/mofo.py index 88a6c5ca..b75ec4d6 100644 --- a/ctms/schemas/mofo.py +++ b/ctms/schemas/mofo.py @@ -1,24 +1,20 @@ -from typing import Optional - from pydantic import ConfigDict, Field from .base import ComparableBase class MozillaFoundationBase(ComparableBase): - mofo_email_id: Optional[str] = Field( + mofo_email_id: str | None = Field( default=None, max_length=255, description="Foriegn key to email in MoFo contact database", ) - mofo_contact_id: Optional[str] = Field( + mofo_contact_id: str | None = Field( default=None, max_length=255, description="Foriegn key to contact in MoFo contact database", ) - mofo_relevant: bool = Field( - default=False, description="Mozilla Foundation is tracking this email" - ) + mofo_relevant: bool = Field(default=False, description="Mozilla Foundation is tracking this email") model_config = ConfigDict(from_attributes=True) diff --git a/ctms/schemas/newsletter.py b/ctms/schemas/newsletter.py index acd20a56..17078139 100644 --- a/ctms/schemas/newsletter.py +++ b/ctms/schemas/newsletter.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Literal, Optional +from typing import Literal from pydantic import UUID4, ConfigDict, Field @@ -6,9 +6,6 @@ from .common import AnyUrlString, ZeroOffsetDatetime from .email import EMAIL_ID_DESCRIPTION, EMAIL_ID_EXAMPLE -if TYPE_CHECKING: - from ctms.models import Newsletter - class NewsletterBase(ComparableBase): """The newsletter subscriptions schema.""" @@ -17,26 +14,20 @@ class NewsletterBase(ComparableBase): description="Basket slug for the newsletter", examples=["mozilla-welcome"], ) - subscribed: bool = Field( - default=True, description="True if subscribed, False when formerly subscribed" - ) - format: Literal["H", "T"] = Field( - default="H", description="Newsletter format, H=HTML, T=Plain Text" - ) - lang: Optional[str] = Field( + subscribed: bool = Field(default=True, description="True if subscribed, False when formerly subscribed") + format: Literal["H", "T"] = Field(default="H", description="Newsletter format, H=HTML, T=Plain Text") + lang: str | None = Field( default="en", min_length=2, max_length=5, description="Newsletter language code, usually 2 lowercase letters", ) - source: Optional[AnyUrlString] = Field( + source: AnyUrlString | None = Field( default=None, description="Source URL of subscription", examples=["https://www.mozilla.org/en-US/"], ) - unsub_reason: Optional[str] = Field( - default=None, description="Reason for unsubscribing" - ) + unsub_reason: str | None = Field(default=None, description="Reason for unsubscribing") def __lt__(self, other): return self.name < other.name diff --git a/ctms/schemas/waitlist.py b/ctms/schemas/waitlist.py index 5021523d..19f4f2e8 100644 --- a/ctms/schemas/waitlist.py +++ b/ctms/schemas/waitlist.py @@ -1,16 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union - from pydantic import UUID4, ConfigDict, Field, model_validator from .base import ComparableBase from .common import AnyUrlString, ZeroOffsetDatetime from .email import EMAIL_ID_DESCRIPTION, EMAIL_ID_EXAMPLE -if TYPE_CHECKING: - from .contact import ContactInBase, ContactPatchSchema - class WaitlistBase(ComparableBase): """ @@ -25,20 +20,14 @@ class WaitlistBase(ComparableBase): description="Basket slug for the waitlist", examples=["new-product"], ) - source: Optional[AnyUrlString] = Field( + source: AnyUrlString | None = Field( default=None, description="Source URL of subscription", examples=["https://www.mozilla.org/en-US/"], ) - fields: dict = Field( - default={}, description="Additional fields", examples=['{"platform": "linux"}'] - ) - subscribed: bool = Field( - default=True, description="True to subscribe, False to unsubscribe" - ) - unsub_reason: Optional[str] = Field( - default=None, description="Reason for unsubscribing" - ) + fields: dict = Field(default={}, description="Additional fields", examples=['{"platform": "linux"}']) + subscribed: bool = Field(default=True, description="True to subscribe, False to unsubscribe") + unsub_reason: str | None = Field(default=None, description="Reason for unsubscribing") def __lt__(self, other): return self.name < other.name @@ -52,7 +41,7 @@ def check_fields(self): if self.name == "relay": class RelayFieldsSchema(ComparableBase): - geo: Optional[str] = CountryField() + geo: str | None = CountryField() model_config = ConfigDict(extra="forbid") RelayFieldsSchema(**self.fields) @@ -60,8 +49,8 @@ class RelayFieldsSchema(ComparableBase): elif self.name == "vpn": class VPNFieldsSchema(ComparableBase): - geo: Optional[str] = CountryField() - platform: Optional[str] = PlatformField() + geo: str | None = CountryField() + platform: str | None = PlatformField() model_config = ConfigDict(extra="forbid") VPNFieldsSchema(**self.fields) @@ -73,8 +62,8 @@ class VPNFieldsSchema(ComparableBase): # This should allow us to onboard most waitlists without specific # code change and service redeployment. class DefaultFieldsSchema(ComparableBase): - geo: Optional[str] = CountryField() - platform: Optional[str] = PlatformField() + geo: str | None = CountryField() + platform: str | None = PlatformField() DefaultFieldsSchema(**self.fields) @@ -132,7 +121,7 @@ class RelayWaitlistSchema(ComparableBase): The Mozilla Relay Waitlist schema for the read-only `relay_waitlist` field. """ - geo: Optional[str] = Field( + geo: str | None = Field( default=None, max_length=100, description="Relay waitlist country", @@ -146,19 +135,16 @@ class VpnWaitlistSchema(ComparableBase): The Mozilla VPN Waitlist schema for the read-only `vpn_waitlist` field """ - geo: Optional[str] = Field( + geo: str | None = Field( default=None, max_length=100, description="VPN waitlist country, FPN_Waitlist_Geo__c in Salesforce", examples=["fr"], ) - platform: Optional[str] = Field( + platform: str | None = Field( default=None, max_length=100, - description=( - "VPN waitlist platforms as comma-separated list," - " FPN_Waitlist_Platform__c in Salesforce" - ), + description="VPN waitlist platforms as comma-separated list, FPN_Waitlist_Platform__c in Salesforce", examples=["ios,mac"], ) model_config = ConfigDict(from_attributes=True) diff --git a/poetry.lock b/poetry.lock index a2824fd3..77f0eca6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. [[package]] name = "alembic" @@ -1838,30 +1838,30 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.8.6" +version = "0.9.3" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "ruff-0.8.6-py3-none-linux_armv6l.whl", hash = "sha256:defed167955d42c68b407e8f2e6f56ba52520e790aba4ca707a9c88619e580e3"}, - {file = "ruff-0.8.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:54799ca3d67ae5e0b7a7ac234baa657a9c1784b48ec954a094da7c206e0365b1"}, - {file = "ruff-0.8.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e88b8f6d901477c41559ba540beeb5a671e14cd29ebd5683903572f4b40a9807"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0509e8da430228236a18a677fcdb0c1f102dd26d5520f71f79b094963322ed25"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:91a7ddb221779871cf226100e677b5ea38c2d54e9e2c8ed847450ebbdf99b32d"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:248b1fb3f739d01d528cc50b35ee9c4812aa58cc5935998e776bf8ed5b251e75"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:bc3c083c50390cf69e7e1b5a5a7303898966be973664ec0c4a4acea82c1d4315"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52d587092ab8df308635762386f45f4638badb0866355b2b86760f6d3c076188"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61323159cf21bc3897674e5adb27cd9e7700bab6b84de40d7be28c3d46dc67cf"}, - {file = "ruff-0.8.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ae4478b1471fc0c44ed52a6fb787e641a2ac58b1c1f91763bafbc2faddc5117"}, - {file = "ruff-0.8.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0c000a471d519b3e6cfc9c6680025d923b4ca140ce3e4612d1a2ef58e11f11fe"}, - {file = "ruff-0.8.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9257aa841e9e8d9b727423086f0fa9a86b6b420fbf4bf9e1465d1250ce8e4d8d"}, - {file = "ruff-0.8.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45a56f61b24682f6f6709636949ae8cc82ae229d8d773b4c76c09ec83964a95a"}, - {file = "ruff-0.8.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:496dd38a53aa173481a7d8866bcd6451bd934d06976a2505028a50583e001b76"}, - {file = "ruff-0.8.6-py3-none-win32.whl", hash = "sha256:e169ea1b9eae61c99b257dc83b9ee6c76f89042752cb2d83486a7d6e48e8f764"}, - {file = "ruff-0.8.6-py3-none-win_amd64.whl", hash = "sha256:f1d70bef3d16fdc897ee290d7d20da3cbe4e26349f62e8a0274e7a3f4ce7a905"}, - {file = "ruff-0.8.6-py3-none-win_arm64.whl", hash = "sha256:7d7fc2377a04b6e04ffe588caad613d0c460eb2ecba4c0ccbbfe2bc973cbc162"}, - {file = "ruff-0.8.6.tar.gz", hash = "sha256:dcad24b81b62650b0eb8814f576fc65cfee8674772a6e24c9b747911801eeaa5"}, + {file = "ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624"}, + {file = "ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c"}, + {file = "ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4"}, + {file = "ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6"}, + {file = "ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730"}, + {file = "ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2"}, + {file = "ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519"}, + {file = "ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b"}, + {file = "ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c"}, + {file = "ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4"}, + {file = "ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b"}, + {file = "ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a"}, ] [[package]] @@ -2425,4 +2425,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.12,<4" -content-hash = "0b3f8532af776d2d5411f0fe7832e44a16418628a47a4a8446788187aacef9c0" +content-hash = "d0922204796fd2f481f3d4bfa0cf0b8fd4ca11b3226692e5868e974ce9e92080" diff --git a/pyproject.toml b/pyproject.toml index 07df5d13..caa45de5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ mypy = "^1.14.1" pre-commit = ">=4.0.1" pytest = ">=8.3.4" pytest-factoryboy = ">=2.7.0" -ruff = "^0.8.6" +ruff = ">=0.9.3" SQLAlchemy-Utils = ">=0.41.2" types-python-dateutil = ">=2.9.0" types-requests = ">=2.32.0" @@ -65,10 +65,29 @@ markers = [ [tool.ruff] target-version = "py312" +line-length = 150 +extend-exclude = ["migrations"] [tool.ruff.lint] -select = [ "PL", "I"] -ignore = [ "PLR2004", "PLR0913" ] +select = [ + "A", # flake8-builtin errors + "B", # bugbear errors + "C4", # flake8-comprehensions errors + "E", # pycodestyle errors + "F", # pyflakes errors + "FAST", # FastAPI + "I", # import sorting + "PL", # pylint errors + "Q", # flake8-quotes errors + "UP", # py-upgrade + "W", # pycodestyle warnings +] +ignore = [ + "A005", # stdlib module shadowing - platform and email + "B008", # function call in default arguments - used for `Depends` in argument defaults. + "PLR2004", # magic value comparison + "PLR0913", # too many arguments +] [tool.coverage.run] omit = [ diff --git a/suppression-list/csv2optout.py b/suppression-list/csv2optout.py index da16aba6..b71cf2a3 100644 --- a/suppression-list/csv2optout.py +++ b/suppression-list/csv2optout.py @@ -3,7 +3,7 @@ import os import re import sys -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path import click @@ -118,9 +118,9 @@ """ -def chunks(l, n): - for i in range(0, len(l), n): - yield l[i : i + n] +def chunks(lst, n): + for i in range(0, len(lst), n): + yield lst[i : i + n] def writefile(path, content): @@ -131,18 +131,12 @@ def writefile(path, content): @click.command() @click.argument("csv_path", type=click.Path(exists=True)) -@click.option( - "--check-input-rows", default=1000, help="Number of rows to check from input CSV." -) +@click.option("--check-input-rows", default=1000, help="Number of rows to check from input CSV.") @click.option("--batch-size", default=10000, help="Number of updates per commit.") @click.option("--files-count", default=3, help="Number of SQL files") @click.option("--sleep-seconds", default=0.1, help="Wait between batches") -@click.option( - "--schedule-sync", default=False, help="Mark update emails as pending sync" -) -@click.option( - "--csv-path-server", default=".", help="Absolute path where to load the CSV from" -) +@click.option("--schedule-sync", default=False, help="Mark update emails as pending sync") +@click.option("--csv-path-server", default=".", help="Absolute path where to load the CSV from") @click.option( "--table-suffix", default=None, @@ -180,8 +174,8 @@ def main( email, date, reason = row assert "@" in email assert re.match(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2} (AM|PM)", date) - except (AssertionError, ValueError): - raise ValueError(f"Line '{row}' does not look right") + except (AssertionError, ValueError) as err: + raise ValueError(f"Line '{row}' does not look right") from err batch_count = 1 + csv_rows_count // batch_size chunk_size = 1 + batch_count // files_count @@ -193,21 +187,21 @@ def main( # # Prepare SQL files # - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) tmp_suffix = table_suffix or now.strftime("%Y%m%dT%H%M") join_batches = [] update_batches = [] for i in range(batch_count): start_idx = i * batch_size end_idx = (i + 1) * batch_size - params = dict( - batch=i + 1, - batch_count=batch_count, - start_idx=start_idx, - end_idx=end_idx, - tmp_suffix=tmp_suffix, - sleep_seconds=sleep_seconds, - ) + params = { + "batch": i + 1, + "batch_count": batch_count, + "start_idx": start_idx, + "end_idx": end_idx, + "tmp_suffix": tmp_suffix, + "sleep_seconds": sleep_seconds, + } join_batches.append(SQL_JOIN_BATCH.format(**params)) update_batches.append( SQL_UPDATE_BATCH.format( @@ -233,16 +227,14 @@ def main( file_count = len(chunked) for i, batch in enumerate(chunked): writefile( - f"{csv_filename}.{i+1}.apply.sql", - "".join(batch) + f"CALL raise_notice('File {i+1}/{file_count} done.');", + f"{csv_filename}.{i + 1}.apply.sql", + "".join(batch) + f"CALL raise_notice('File {i + 1}/{file_count} done.');", ) - logger.info( - f"Produced {file_count} files, with {chunk_size} commits ({chunk_size * batch_size} updates)." - ) + logger.info(f"Produced {file_count} files, with {chunk_size} commits ({chunk_size * batch_size} updates).") writefile( - f"{csv_filename}.{file_count+1}.post.sql", + f"{csv_filename}.{file_count + 1}.post.sql", SQL_COMMANDS_POST.format( tmp_suffix=tmp_suffix, ), diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index 0650744f..3c15c1f3 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -1 +1,2 @@ +# ruff: noqa: F401 -- Allow unused imports from . import models diff --git a/tests/factories/models.py b/tests/factories/models.py index 5dfce618..d69a6e08 100644 --- a/tests/factories/models.py +++ b/tests/factories/models.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from uuid import uuid4 import factory @@ -17,10 +17,7 @@ class RelatedFactoryVariableList(factory.RelatedFactoryList): def call(self, instance, step, context): size = context.extra.pop("size", self.size) assert isinstance(size, int) - return [ - super(factory.RelatedFactoryList, self).call(instance, step, context) - for _ in range(size) - ] + return [super(factory.RelatedFactoryList, self).call(instance, step, context) for _ in range(size)] class BaseSQLAlchemyModelFactory(SQLAlchemyModelFactory): @@ -116,7 +113,7 @@ class Meta: double_opt_in = False has_opted_out_of_email = False - create_timestamp = factory.LazyFunction(lambda: datetime.now(timezone.utc)) + create_timestamp = factory.LazyFunction(lambda: datetime.now(UTC)) update_timestamp = factory.LazyAttribute(lambda obj: obj.create_timestamp) @factory.post_generation diff --git a/tests/integration/test_basket_waitlist_subscription.py b/tests/integration/test_basket_waitlist_subscription.py index ce207ac9..ca9f0eef 100644 --- a/tests/integration/test_basket_waitlist_subscription.py +++ b/tests/integration/test_basket_waitlist_subscription.py @@ -18,9 +18,7 @@ class Settings(BaseSettings): # We initialize CTMS api client id/secret in `ctms-db-init.sql` ctms_client_id: str = "id_integration-test" ctms_client_secret: str - model_config = SettingsConfigDict( - env_file=os.path.join(TEST_FOLDER, "basket.env"), extra="ignore" - ) + model_config = SettingsConfigDict(env_file=os.path.join(TEST_FOLDER, "basket.env"), extra="ignore") settings = Settings() @@ -35,9 +33,7 @@ class Settings(BaseSettings): @pytest.fixture(scope="session", autouse=True) def adjust_backoff_logger(pytestconfig): # Detect whether pytest was run using `-v` or `-vv` and logging. - backoff_logger.setLevel( - logging.INFO if pytestconfig.getoption("verbose") > 0 else logging.ERROR - ) + backoff_logger.setLevel(logging.INFO if pytestconfig.getoption("verbose") > 0 else logging.ERROR) @pytest.fixture(scope="session") @@ -58,7 +54,7 @@ def ctms_headers(): resp.raise_for_status() token = resp.json() return { - "Authorization": f'{token["token_type"]} {token["access_token"]}', + "Authorization": f"{token['token_type']} {token['access_token']}", } @@ -117,9 +113,7 @@ def test_vpn_waitlist(ctms_headers): email = f"integration-test-{uuid4()}@restmail.net" vpn_waitlist_slug = "guardian-vpn-waitlist" - basket_subscribe( - email, vpn_waitlist_slug, fpn_country="us", fpn_platform="ios,android" - ) + basket_subscribe(email, vpn_waitlist_slug, fpn_country="us", fpn_platform="ios,android") # 2. Basket should have set the `vpn_waitlist` field/data. # Wait for the worker to have processed the request. diff --git a/tests/unit/bin/test_client_credentials.py b/tests/unit/bin/test_client_credentials.py index ec21be17..0274eed6 100644 --- a/tests/unit/bin/test_client_credentials.py +++ b/tests/unit/bin/test_client_credentials.py @@ -6,9 +6,7 @@ @pytest.fixture def existing_client(dbsession): - client = ApiClient( - client_id="id_existing", email="existing@example.com", hashed_secret="password" - ) + client = ApiClient(client_id="id_existing", email="existing@example.com", hashed_secret="password") dbsession.add(client) dbsession.flush() return client @@ -38,9 +36,7 @@ def test_create_explicit_id(dbsession, settings): def test_create_disabled(dbsession, settings): """New client credentials can be generated as disabled.""" - ret = main( - dbsession, settings, ["test2", "--email", "test@example.com", "--disable"] - ) + ret = main(dbsession, settings, ["test2", "--email", "test@example.com", "--disable"]) assert ret == 0 client = dbsession.query(ApiClient).one() @@ -57,9 +53,7 @@ def test_create_email_required(dbsession, settings): assert dbsession.query(ApiClient).first() is None -@pytest.mark.parametrize( - "client_id", ("service.mozilla.com", "1-800-Contacts", "under_score.js") -) +@pytest.mark.parametrize("client_id", ("service.mozilla.com", "1-800-Contacts", "under_score.js")) def test_create_valid_client_id(dbsession, settings, client_id): """Some punctuation is allowed.""" ret = main(dbsession, settings, [client_id, "--email", "test@example.com"]) @@ -113,9 +107,7 @@ def test_update_enable(dbsession, settings, existing_client): def test_update_enable_and_disable_fails(dbsession, settings, existing_client): """Picking enable and disable is an error.""" - ret = main( - dbsession, settings, [existing_client.client_id, "--disable", "--enable"] - ) + ret = main(dbsession, settings, [existing_client.client_id, "--disable", "--enable"]) assert ret == 1 client = dbsession.query(ApiClient).one() assert client.enabled diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5e54479d..c88c2b82 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -2,9 +2,9 @@ import logging import os.path -from datetime import datetime, timezone +from collections.abc import Callable +from datetime import UTC, datetime from time import mktime -from typing import Callable, Optional from urllib.parse import urlparse from uuid import UUID @@ -62,9 +62,9 @@ def _gather_examples(schema_class) -> dict[str, str]: return examples -def unix_timestamp(the_time: Optional[datetime] = None) -> int: +def unix_timestamp(the_time: datetime | None = None) -> int: """Create a UNIX timestamp from a datetime or now""" - the_time = the_time or datetime.now(tz=timezone.utc) + the_time = the_time or datetime.now(tz=UTC) return int(mktime(the_time.timetuple())) @@ -178,23 +178,15 @@ def create_full_contact(db, contact: ContactSchema): specified_newsletters_by_name = {nl.name: nl for nl in contact.newsletters} if specified_newsletters_by_name: for newsletter_in_db in get_newsletters_by_email_id(db, contact.email.email_id): - newsletter_in_db.create_timestamp = specified_newsletters_by_name[ - newsletter_in_db.name - ].create_timestamp - newsletter_in_db.update_timestamp = specified_newsletters_by_name[ - newsletter_in_db.name - ].update_timestamp + newsletter_in_db.create_timestamp = specified_newsletters_by_name[newsletter_in_db.name].create_timestamp + newsletter_in_db.update_timestamp = specified_newsletters_by_name[newsletter_in_db.name].update_timestamp db.add(newsletter_in_db) specified_waitlists_by_name = {wl.name: wl for wl in contact.waitlists} if specified_waitlists_by_name: for waitlist_in_db in get_waitlists_by_email_id(db, contact.email.email_id): - waitlist_in_db.create_timestamp = specified_waitlists_by_name[ - waitlist_in_db.name - ].create_timestamp - waitlist_in_db.update_timestamp = specified_waitlists_by_name[ - waitlist_in_db.name - ].update_timestamp + waitlist_in_db.create_timestamp = specified_waitlists_by_name[waitlist_in_db.name].create_timestamp + waitlist_in_db.update_timestamp = specified_waitlists_by_name[waitlist_in_db.name].update_timestamp db.add(waitlist_in_db) db.commit() @@ -390,18 +382,9 @@ def example_contact_data() -> ContactSchema: return ContactSchema( amo=schemas.AddOnsSchema(**_gather_examples(schemas.AddOnsSchema)), email=schemas.EmailSchema(**_gather_examples(schemas.EmailSchema)), - fxa=schemas.FirefoxAccountsSchema( - **_gather_examples(schemas.FirefoxAccountsSchema) - ), - newsletters=ContactSchema.model_json_schema()["properties"]["newsletters"][ - "examples" - ][0], - waitlists=[ - schemas.WaitlistTableSchema(**example) - for example in ContactSchema.model_json_schema()["properties"]["waitlists"][ - "examples" - ][0] - ], + fxa=schemas.FirefoxAccountsSchema(**_gather_examples(schemas.FirefoxAccountsSchema)), + newsletters=ContactSchema.model_json_schema()["properties"]["newsletters"]["examples"][0], + waitlists=[schemas.WaitlistTableSchema(**example) for example in ContactSchema.model_json_schema()["properties"]["waitlists"]["examples"][0]], ) @@ -485,9 +468,7 @@ def client(anon_client): """A test client that passed a valid OAuth2 token.""" def test_api_client(): - return ApiClientSchema( - client_id="test_client", email="test_client@example.com", enabled=True - ) + return ApiClientSchema(client_id="test_client", email="test_client@example.com", enabled=True) app.dependency_overrides[get_api_client] = test_api_client yield anon_client @@ -529,9 +510,7 @@ def metrics(setup_metrics): @pytest.fixture def client_id_and_secret(dbsession): """Return valid OAuth2 client_id and client_secret.""" - api_client = ApiClientSchema( - client_id="id_db_api_client", email="db_api_client@example.com", enabled=True - ) + api_client = ApiClientSchema(client_id="id_db_api_client", email="db_api_client@example.com", enabled=True) secret = "secret_what_a_weird_random_string" # pragma: allowlist secret create_api_client(dbsession, api_client, secret) dbsession.flush() @@ -549,7 +528,7 @@ def _add( code: int = 201, stored_contacts: int = 1, check_redirect: bool = True, - query_fields: Optional[dict] = None, + query_fields: dict | None = None, check_written: bool = True, ): if query_fields is None: @@ -575,23 +554,15 @@ def _check_written(field, getter, result_list=False): if sample.model_dump().get(field) and code in {200, 201}: if field in fields_not_written: if result_list: - assert ( - results == [] - ), f"{email_id} has field `{field}` but it is _default_ and it should _not_ have been written to db" + assert results == [], f"{email_id} has field `{field}` but it is _default_ and it should _not_ have been written to db" else: - assert ( - results is None - ), f"{email_id} has field `{field}` but it is _default_ and it should _not_ have been written to db" + assert results is None, f"{email_id} has field `{field}` but it is _default_ and it should _not_ have been written to db" else: assert results, f"{email_id} has field `{field}` and it should have been written to db" elif result_list: - assert ( - results == [] - ), f"{email_id} does not have field `{field}` and it should _not_ have been written to db" + assert results == [], f"{email_id} does not have field `{field}` and it should _not_ have been written to db" else: - assert ( - results is None - ), f"{email_id} does not have field `{field}` and it should _not_ have been written to db" + assert results is None, f"{email_id} does not have field `{field}` and it should _not_ have been written to db" if check_written: _check_written("amo", get_amo_by_email_id) @@ -625,10 +596,10 @@ def _add( modifier: Callable[[ContactSchema], ContactSchema] = lambda x: x, code: int = 201, stored_contacts: int = 1, - query_fields: Optional[dict] = None, + query_fields: dict | None = None, check_written: bool = True, - record: Optional[ContactSchema] = None, - new_default_fields: Optional[set] = None, + record: ContactSchema | None = None, + new_default_fields: set | None = None, ): if record: contact = record @@ -639,9 +610,7 @@ def _add( new_default_fields = new_default_fields or set() sample = contact.model_copy(deep=True) sample = modifier(sample) - resp = client.put( - f"/ctms/{sample.email.email_id}", content=sample.model_dump_json() - ) + resp = client.put(f"/ctms/{sample.email.email_id}", content=sample.model_dump_json()) assert resp.status_code == code, resp.text saved = get_contacts_by_any_id(dbsession, **query_fields) assert len(saved) == stored_contacts @@ -657,16 +626,15 @@ def _check_written(field, getter): results = getter(dbsession, written_id) if sample.model_dump().get(field) and code in {200, 201}: if field in fields_not_written or field in new_default_fields: - assert ( - results is None - or (isinstance(results, list) and len(results) == 0) - ), f"{sample_email_id} has field `{field}` but it is _default_ and it should _not_ have been written to db" + assert results is None or (isinstance(results, list) and len(results) == 0), ( + f"{sample_email_id} has field `{field}` but it is _default_ and it should _not_ have been written to db" + ) else: assert results, f"{sample_email_id} has field `{field}` and it should have been written to db" else: - assert ( - results is None or (isinstance(results, list) and len(results) == 0) - ), f"{sample_email_id} does not have field `{field}` and it should _not_ have been written to db" + assert results is None or (isinstance(results, list) and len(results) == 0), ( + f"{sample_email_id} does not have field `{field}` and it should _not_ have been written to db" + ) if check_written: _check_written("amo", get_amo_by_email_id) diff --git a/tests/unit/routers/contacts/test_api.py b/tests/unit/routers/contacts/test_api.py index 0bd23e85..50dcab12 100644 --- a/tests/unit/routers/contacts/test_api.py +++ b/tests/unit/routers/contacts/test_api.py @@ -1,6 +1,5 @@ """Unit tests for cross-API functionality""" -from typing import Optional, Set from uuid import uuid4 import pytest @@ -97,12 +96,8 @@ def _unsubscribe_newsletter(contact): def _subscribe_newsletters_and_change(contact): if contact.newsletters: contact.newsletters[-1].subscribed = not contact.newsletters[-1].subscribed - contact.newsletters.append( - NewsletterInSchema(name="a-newsletter", subscribed=False) - ) - contact.newsletters.append( - NewsletterInSchema(name="another-newsletter", subscribed=True) - ) + contact.newsletters.append(NewsletterInSchema(name="a-newsletter", subscribed=False)) + contact.newsletters.append(NewsletterInSchema(name="another-newsletter", subscribed=True)) def _subscribe_waitlists_and_change(contact): @@ -141,7 +136,7 @@ def _compare_written_contacts( sample, email_id, ids_should_be_identical: bool = True, - new_default_fields: Optional[set] = None, + new_default_fields: set | None = None, ): fields_not_written = new_default_fields or set() @@ -159,7 +154,7 @@ def _compare_written_contacts( assert saved_contact.idempotent_equal(sample) -def find_default_fields(contact: ContactSchema) -> Set[str]: +def find_default_fields(contact: ContactSchema) -> set[str]: """Return names of fields that contain default values only""" default_fields = set() if hasattr(contact, "amo") and contact.amo and contact.amo.is_default(): @@ -199,20 +194,12 @@ def test_post_get_put(client, post_contact, put_contact, update_fetched): # `relay_waitlist` as input. # If we don't strip these two fields before turning the data into # a `ContactInSchema`, they will create waitlist objects. - without_alias_fields = { - k: v - for k, v in resp.json().items() - if k not in ("vpn_waitlist", "relay_waitlist") - } + without_alias_fields = {k: v for k, v in resp.json().items() if k not in ("vpn_waitlist", "relay_waitlist")} fetched = ContactInSchema(**without_alias_fields) update_fetched(fetched) new_default_fields = find_default_fields(fetched) # We set new_default_fields here because the returned response above # _includes_ defaults for many fields and we want to not write # them when the record is PUT again - saved_contacts, sample, email_id = put_contact( - record=fetched, new_default_fields=new_default_fields - ) - _compare_written_contacts( - saved_contacts[0], sample, email_id, new_default_fields=new_default_fields - ) + saved_contacts, sample, email_id = put_contact(record=fetched, new_default_fields=new_default_fields) + _compare_written_contacts(saved_contacts[0], sample, email_id, new_default_fields=new_default_fields) diff --git a/tests/unit/routers/contacts/test_api_get.py b/tests/unit/routers/contacts/test_api_get.py index fbe332d7..0e511341 100644 --- a/tests/unit/routers/contacts/test_api_get.py +++ b/tests/unit/routers/contacts/test_api_get.py @@ -190,9 +190,7 @@ def test_get_ctms_not_found(client, dbsession): @pytest.mark.parametrize("waitlist_name", ["vpn", "relay"]) -def test_get_ctms_without_geo_in_waitlist( - waitlist_name, client, dbsession, waitlist_factory -): +def test_get_ctms_without_geo_in_waitlist(waitlist_name, client, dbsession, waitlist_factory): existing_waitlist = waitlist_factory(name=waitlist_name, fields={}) dbsession.flush() email_id = existing_waitlist.email.email_id diff --git a/tests/unit/routers/contacts/test_api_patch.py b/tests/unit/routers/contacts/test_api_patch.py index 8c8156c0..90c45a8c 100644 --- a/tests/unit/routers/contacts/test_api_patch.py +++ b/tests/unit/routers/contacts/test_api_patch.py @@ -63,9 +63,7 @@ def swap_bool(existing): @pytest.mark.parametrize("group_name,key,value", (patch_single_value_params)) -def test_patch_one_new_value_mostly_empty( - client, email_factory, group_name, key, value -): +def test_patch_one_new_value_mostly_empty(client, email_factory, group_name, key, value): """PATCH can update a single value.""" email = email_factory() contact = ContactSchema.from_email(email) @@ -182,9 +180,7 @@ def test_patch_to_default(client, email_factory, group_name, key): with_amo=True, ) - expected = jsonable_encoder( - CTMSResponse(**ContactSchema.from_email(email).model_dump()) - ) + expected = jsonable_encoder(CTMSResponse(**ContactSchema.from_email(email).model_dump())) existing_value = expected[group_name][key] # Load the default value from the schema @@ -203,9 +199,7 @@ def test_patch_to_default(client, email_factory, group_name, key): expected[group_name][key] = default_value assert existing_value != default_value - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert actual["status"] == "ok" @@ -219,9 +213,7 @@ def test_patch_cannot_set_timestamps(client, email_factory): """PATCH can not set timestamps directly.""" email = email_factory(with_amo=True) - expected = jsonable_encoder( - CTMSResponse(**ContactSchema.from_email(email).model_dump()) - ) + expected = jsonable_encoder(CTMSResponse(**ContactSchema.from_email(email).model_dump())) new_ts = datetime.now(tz=UTC).isoformat() assert expected["amo"]["create_timestamp"] == email.amo.create_timestamp.isoformat() assert expected["amo"]["create_timestamp"] != new_ts @@ -238,9 +230,7 @@ def test_patch_cannot_set_timestamps(client, email_factory): "update_timestamp": new_ts, }, } - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert actual["status"] == "ok" @@ -313,9 +303,7 @@ def test_patch_error_on_id_conflict(client, dbsession, group_name, key, email_fa mofo_email_id=str(uuid4()), mofo_contact_id=str(uuid4()), ), - fxa=FirefoxAccountsInSchema( - fxa_id="1337", primary_email="fxa-conflict@example.com" - ), + fxa=FirefoxAccountsInSchema(fxa_id="1337", primary_email="fxa-conflict@example.com"), ) create_full_contact(dbsession, conflicting_data) @@ -333,12 +321,7 @@ def test_patch_error_on_id_conflict(client, dbsession, group_name, key, email_fa email_id = existing_contact.email.email_id resp = client.patch(f"/ctms/{email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 409 - assert resp.json() == { - "detail": ( - "Contact with primary_email, basket_token, mofo_email_id, or fxa_id" - " already exists" - ) - } + assert resp.json() == {"detail": ("Contact with primary_email, basket_token, mofo_email_id, or fxa_id already exists")} def test_patch_to_subscribe(client, email_factory): @@ -346,9 +329,7 @@ def test_patch_to_subscribe(client, email_factory): email = email_factory(newsletters=1) patch_data = {"newsletters": [{"name": "zzz-newsletter"}]} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert len(actual["newsletters"]) == 2 @@ -369,9 +350,7 @@ def test_patch_to_update_subscription(client, newsletter_factory): existing_newsletter = newsletter_factory() email_id = str(existing_newsletter.email.email_id) - patch_data = { - "newsletters": [{"name": existing_newsletter.name, "format": "H", "lang": "XX"}] - } + patch_data = {"newsletters": [{"name": existing_newsletter.name, "format": "H", "lang": "XX"}]} resp = client.patch(f"/ctms/{email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() @@ -406,9 +385,7 @@ def test_patch_to_unsubscribe(client, email_factory, newsletter_factory): } ] } - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert len(actual["newsletters"]) == len(email.newsletters) @@ -438,9 +415,7 @@ def test_patch_to_unsubscribe_but_not_subscribed(client, email_factory): } ] } - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert len(actual["newsletters"]) == 1 @@ -452,9 +427,7 @@ def test_patch_unsubscribe_all(client, email_factory): email = email_factory(newsletters=2) patch_data = {"newsletters": "UNSUBSCRIBE"} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert len(actual["newsletters"]) == 2 @@ -467,9 +440,7 @@ def test_patch_to_delete_group(client, email_factory, group_name): email = email_factory(with_amo=True, with_fxa=True, with_mofo=True) patch_data = {group_name: "DELETE"} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() defaults = { @@ -487,9 +458,7 @@ def test_patch_to_delete_deleted_group(client, email_factory): assert email.amo is None patch_data = {"mofo": "DELETE"} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() @@ -502,9 +471,7 @@ def test_patch_will_validate_waitlist_fields(client, email_factory): email = email_factory() patch_data = {"waitlists": [{"name": "future-tech", "source": 42}]} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 422 details = resp.json() assert details["detail"][0]["loc"] == [ @@ -521,9 +488,7 @@ def test_patch_to_add_a_waitlist(client, email_factory): email = email_factory() patch_data = {"waitlists": [{"name": "future-tech", "fields": {"geo": "es"}}]} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() [new_waitlist] = actual["waitlists"] @@ -542,9 +507,7 @@ def test_patch_does_not_add_an_unsubscribed_waitlist(client, email_factory): email = email_factory() patch_data = {"waitlists": [{"name": "future-tech", "subscribed": False}]} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert len(actual["waitlists"]) == 0 @@ -555,15 +518,9 @@ def test_patch_to_update_a_waitlist(client, email_factory, waitlist_factory): email = email_factory() waitlist = waitlist_factory(fields={"geo": "fr"}, email=email) - patched_waitlist = ( - WaitlistInSchema.model_validate(waitlist) - .model_copy(update={"fields": {"geo": "ca"}}) - .model_dump() - ) + patched_waitlist = WaitlistInSchema.model_validate(waitlist).model_copy(update={"fields": {"geo": "ca"}}).model_dump() patch_data = {"waitlists": [patched_waitlist]} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert actual["waitlists"][0]["fields"]["geo"] == "ca" @@ -574,16 +531,8 @@ def test_patch_to_remove_a_waitlist(client, email_factory, waitlist_factory): email = email_factory() waitlist_factory(name="bye-bye", email=email) - patch_data = { - "waitlists": [ - WaitlistInSchema( - name="bye-bye", subscribed=False, unsub_reason="Not interested" - ).model_dump() - ] - } - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + patch_data = {"waitlists": [WaitlistInSchema(name="bye-bye", subscribed=False, unsub_reason="Not interested").model_dump()]} + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() [unsubscribed] = actual["waitlists"] @@ -597,9 +546,7 @@ def test_patch_to_remove_all_waitlists(client, email_factory): assert all(wl.subscribed for wl in email.waitlists) patch_data = {"waitlists": "UNSUBSCRIBE"} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() @@ -611,18 +558,14 @@ def test_patch_preserves_waitlists_if_omitted(client, email_factory): email = email_factory(waitlists=2) patch_data = {"email": {"first_name": "Jeff"}} - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 actual = resp.json() assert len(actual["waitlists"]) == len(email.waitlists) -def test_subscribe_to_relay_newsletter_turned_into_relay_waitlist( - client, email_factory -): +def test_subscribe_to_relay_newsletter_turned_into_relay_waitlist(client, email_factory): email = email_factory() patch_data = { @@ -630,7 +573,5 @@ def test_subscribe_to_relay_newsletter_turned_into_relay_waitlist( "vpn_waitlist": {"geo": "fr", "platform": "windows"}, } - resp = client.patch( - f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True - ) + resp = client.patch(f"/ctms/{email.email_id}", json=patch_data, follow_redirects=True) assert resp.status_code == 200 # Not 400 diff --git a/tests/unit/routers/contacts/test_api_post.py b/tests/unit/routers/contacts/test_api_post.py index f07a0c83..ef16550c 100644 --- a/tests/unit/routers/contacts/test_api_post.py +++ b/tests/unit/routers/contacts/test_api_post.py @@ -11,11 +11,7 @@ def test_create_basic_no_email_id(client, dbsession): """Most straightforward contact creation succeeds when email_id is not a key.""" - contact_data = jsonable_encoder( - schemas.ContactInSchema( - email={"primary_email": "hello@example.com"} - ).model_dump(exclude_none=True) - ) + contact_data = jsonable_encoder(schemas.ContactInSchema(email={"primary_email": "hello@example.com"}).model_dump(exclude_none=True)) assert "email_id" not in contact_data["email"].keys() resp = client.post("/ctms", json=contact_data) @@ -27,9 +23,7 @@ def test_create_basic_no_email_id(client, dbsession): def test_create_basic_email_id_is_none(client, dbsession): """Most straightforward contact creation succeeds when email_id is not a key.""" - contact_data = jsonable_encoder( - schemas.ContactInSchema(email={"primary_email": "hello@example.com"}) - ) + contact_data = jsonable_encoder(schemas.ContactInSchema(email={"primary_email": "hello@example.com"})) assert contact_data["email"]["email_id"] is None resp = client.post("/ctms", json=contact_data) @@ -42,11 +36,7 @@ def test_create_basic_with_id(client, dbsession, email_factory): """Most straightforward contact creation succeeds when email_id is specified.""" provided_email_id = str(uuid4()) - contact_data = jsonable_encoder( - schemas.ContactInSchema( - email={"email_id": provided_email_id, "primary_email": "hello@example.com"} - ) - ) + contact_data = jsonable_encoder(schemas.ContactInSchema(email={"email_id": provided_email_id, "primary_email": "hello@example.com"})) assert contact_data["email"]["email_id"] == provided_email_id resp = client.post("/ctms", json=contact_data) @@ -58,9 +48,7 @@ def test_create_basic_with_id(client, dbsession, email_factory): def test_create_basic_idempotent(client, dbsession): """Creating a contact works across retries.""" - contact_data = jsonable_encoder( - schemas.ContactInSchema(email={"primary_email": "hello@example.com"}) - ) + contact_data = jsonable_encoder(schemas.ContactInSchema(email={"primary_email": "hello@example.com"})) resp = client.post("/ctms", json=contact_data) assert resp.status_code == 201 @@ -75,11 +63,7 @@ def test_create_basic_idempotent(client, dbsession): def test_create_basic_with_id_collision(client, email_factory): """Creating a contact with the same id but different data fails.""" - contact_data = jsonable_encoder( - schemas.ContactInSchema( - email={"primary_email": "hello@example.com", "email_lang": "en"} - ) - ) + contact_data = jsonable_encoder(schemas.ContactInSchema(email={"primary_email": "hello@example.com", "email_lang": "en"})) resp = client.post("/ctms", json=contact_data) assert resp.status_code == 201 @@ -100,9 +84,7 @@ def test_create_basic_with_email_collision(client, email_factory): colliding_email = "foo@example.com" email_factory(primary_email=colliding_email) - contact_data = jsonable_encoder( - schemas.ContactInSchema(email={"primary_email": colliding_email}) - ) + contact_data = jsonable_encoder(schemas.ContactInSchema(email={"primary_email": colliding_email})) resp = client.post("/ctms", json=contact_data) assert resp.status_code == 409 diff --git a/tests/unit/routers/contacts/test_api_put.py b/tests/unit/routers/contacts/test_api_put.py index 8df44b3b..7cdf92ab 100644 --- a/tests/unit/routers/contacts/test_api_put.py +++ b/tests/unit/routers/contacts/test_api_put.py @@ -3,7 +3,6 @@ import logging from uuid import uuid4 -import pytest from fastapi.encoders import jsonable_encoder from ctms import models @@ -13,9 +12,7 @@ def test_create_or_update_basic_id_is_different(client): """This should fail since we require an email_id to PUT""" - contact = ContactPutSchema( - email={"email_id": str(uuid4()), "primary_email": "hello@example.com"} - ) + contact = ContactPutSchema(email={"email_id": str(uuid4()), "primary_email": "hello@example.com"}) # This id is different from the one in the contact resp = client.put( f"/ctms/{str(uuid4())}", @@ -28,9 +25,7 @@ def test_create_or_update_basic_id_is_different(client): def test_create_or_update_basic_id_is_none(client): """This should fail since we require an email_id to PUT""" - contact_data = ContactPutSchema.model_construct( - EmailInSchema(primary_email="foo@example.com") - ) + contact_data = ContactPutSchema.model_construct(EmailInSchema(primary_email="foo@example.com")) resp = client.put(f"/ctms/{str(uuid4())}", json=jsonable_encoder(contact_data)) assert resp.status_code == 422 @@ -38,9 +33,7 @@ def test_create_or_update_basic_id_is_none(client): def test_create_or_update_basic_empty_db(client): """Most straightforward contact creation succeeds when there is no collision""" email_id = str(uuid4()) - contact_data = ContactPutSchema( - email={"email_id": email_id, "primary_email": "foo@example.com"} - ) + contact_data = ContactPutSchema(email={"email_id": email_id, "primary_email": "foo@example.com"}) resp = client.put(f"/ctms/{email_id}", json=jsonable_encoder(contact_data)) assert resp.status_code == 201 @@ -49,9 +42,7 @@ def test_create_or_update_identical(client, dbsession): """Writing the same thing twice works both times""" email_id = str(uuid4()) - contact_data = ContactPutSchema( - email={"email_id": email_id, "primary_email": "foo@example.com"} - ) + contact_data = ContactPutSchema(email={"email_id": email_id, "primary_email": "foo@example.com"}) resp = client.put(f"/ctms/{email_id}", json=jsonable_encoder(contact_data)) assert resp.status_code == 201 @@ -69,9 +60,7 @@ def test_create_or_update_change_primary_email(client, email_factory, dbsession) email_id = str(uuid4()) email_factory(email_id=email_id, primary_email="foo@example.com") - contact_data = ContactPutSchema( - email={"email_id": email_id, "primary_email": "bar@example.com"} - ) + contact_data = ContactPutSchema(email={"email_id": email_id, "primary_email": "bar@example.com"}) resp = client.put(f"/ctms/{email_id}", json=jsonable_encoder(contact_data)) assert resp.status_code == 201 @@ -84,9 +73,7 @@ def test_create_or_update_change_basket_token(client, email_factory, dbsession): """We can update a basket_token given a ctms ID""" email_id = str(uuid4()) - email_factory( - email_id=email_id, primary_email="foo@example.com", basket_token=uuid4() - ) + email_factory(email_id=email_id, primary_email="foo@example.com", basket_token=uuid4()) new_basket_token = str(uuid4()) contact_data = ContactPutSchema( @@ -118,9 +105,7 @@ def test_create_or_update_with_basket_collision(client, email_factory): "basket_token": existing_basket_token, } ) - resp = client.put( - f"/ctms/{new_contact_email_id}", json=jsonable_encoder(contact_data) - ) + resp = client.put(f"/ctms/{new_contact_email_id}", json=jsonable_encoder(contact_data)) assert resp.status_code == 409 @@ -138,9 +123,7 @@ def test_create_or_update_with_email_collision(client, email_factory): "primary_email": existing_email_address, } ) - resp = client.put( - f"/ctms/{new_contact_email_id}", json=jsonable_encoder(contact_data) - ) + resp = client.put(f"/ctms/{new_contact_email_id}", json=jsonable_encoder(contact_data)) assert resp.status_code == 409 diff --git a/tests/unit/routers/contacts/test_bulk.py b/tests/unit/routers/contacts/test_bulk.py index 909fd91d..5f3d462b 100644 --- a/tests/unit/routers/contacts/test_bulk.py +++ b/tests/unit/routers/contacts/test_bulk.py @@ -61,9 +61,7 @@ def test_get_ctms_bulk_by_timerange(client, email_factory): ) last_email = email_factory() - after = BulkRequestSchema.compressor_for_bulk_encoded_details( - first_email.email_id, first_email.update_timestamp - ) + after = BulkRequestSchema.compressor_for_bulk_encoded_details(first_email.email_id, first_email.update_timestamp) limit = 1 start = first_email.update_timestamp - timedelta(hours=12) start_time = urllib.parse.quote_plus(start.isoformat()) @@ -107,9 +105,7 @@ def test_get_ctms_bulk_by_timerange_no_results(client, email_factory): ) first_email = sorted_list[0] last_email = sorted_list[-1] - after = BulkRequestSchema.compressor_for_bulk_encoded_details( - last_email.email_id, last_email.update_timestamp - ) + after = BulkRequestSchema.compressor_for_bulk_encoded_details(last_email.email_id, last_email.update_timestamp) limit = 1 start = first_email.update_timestamp - timedelta(hours=12) start_time = urllib.parse.quote_plus(start.isoformat()) diff --git a/tests/unit/routers/contacts/test_private_api.py b/tests/unit/routers/contacts/test_private_api.py index caa25918..7cb2fe9e 100644 --- a/tests/unit/routers/contacts/test_private_api.py +++ b/tests/unit/routers/contacts/test_private_api.py @@ -1,13 +1,13 @@ """Tests for the private APIs that may be removed.""" import json -from typing import Any, Tuple +from typing import Any import pytest from ctms.schemas.contact import ContactSchema -API_TEST_CASES: Tuple[Tuple[str, Any], ...] = ( +API_TEST_CASES: tuple[tuple[str, Any], ...] = ( ("/identities", {"basket_token": "c4a7d759-bb52-457b-896b-90f1d3ef8433"}), ("/identity/332de237-cab7-4461-bcc3-48e68f42bd5c", {}), ) @@ -115,9 +115,7 @@ def test_get_identities_by_two_alt_id_match(client, email_factory): assert fxa_email resp = client.get(f"/identities?sfdc_id={sfdc_id}&fxa_primary_email={fxa_email}") - identity = json.loads( - ContactSchema.from_email(email).as_identity_response().model_dump_json() - ) + identity = json.loads(ContactSchema.from_email(email).as_identity_response().model_dump_json()) assert resp.status_code == 200 assert resp.json() == [identity] @@ -127,9 +125,7 @@ def test_get_identities_by_two_alt_id_mismatch_fails(client, email_factory): email_1 = email_factory(with_amo=True) email_2 = email_factory(with_amo=True) - resp = client.get( - f"/identities?primary_email={email_1.primary_email}&amo_user_id={email_2.amo.user_id}" - ) + resp = client.get(f"/identities?primary_email={email_1.primary_email}&amo_user_id={email_2.amo.user_id}") assert resp.status_code == 200 assert resp.json() == [] @@ -178,9 +174,7 @@ def test_get_identities_with_no_alt_ids_fails(client, dbsession): ("mofo_email_id", "cad092ec-a71a-4df5-aa92-517959caeecb"), ], ) -def test_get_identities_with_unknown_ids_fails( - client, dbsession, alt_id_name, alt_id_value -): +def test_get_identities_with_unknown_ids_fails(client, dbsession, alt_id_name, alt_id_value): """GET /identities returns an empty list if no IDs match.""" resp = client.get(f"/identities?{alt_id_name}={alt_id_value}") assert resp.status_code == 200 diff --git a/tests/unit/routers/test_platform.py b/tests/unit/routers/test_platform.py index 82d5896a..fdaea3c6 100644 --- a/tests/unit/routers/test_platform.py +++ b/tests/unit/routers/test_platform.py @@ -28,7 +28,7 @@ def test_read_version(anon_client): root_dir = here.parents[3] app.state.APP_DIR = root_dir version_path = Path(root_dir / "version.json") - with open(version_path, "r", encoding="utf8") as vp_file: + with open(version_path, encoding="utf8") as vp_file: version_contents = vp_file.read() expected = json.loads(version_contents) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index afdf1c28..4c8bd145 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,7 +1,7 @@ """Test authentication""" import logging -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import jwt import pytest @@ -39,13 +39,9 @@ def test_post_token_header(anon_client, test_token_settings, client_id_and_secre content = resp.json() assert content["token_type"] == "bearer" assert content["expires_in"] == 5 * 60 - payload = jwt.decode( - content["access_token"], test_token_settings["secret_key"], algorithms=["HS256"] - ) + payload = jwt.decode(content["access_token"], test_token_settings["secret_key"], algorithms=["HS256"]) assert payload["sub"] == f"api_client:{client_id}" - expected_expires = ( - datetime.now(timezone.utc) + test_token_settings["expires_delta"] - ).timestamp() + expected_expires = (datetime.now(UTC) + test_token_settings["expires_delta"]).timestamp() assert -2.0 < (expected_expires - payload["exp"]) < 2.0 @@ -64,13 +60,9 @@ def test_post_token_form_data(anon_client, test_token_settings, client_id_and_se assert resp.status_code == 200 content = resp.json() assert content["token_type"] == "bearer" - payload = jwt.decode( - content["access_token"], test_token_settings["secret_key"], algorithms=["HS256"] - ) + payload = jwt.decode(content["access_token"], test_token_settings["secret_key"], algorithms=["HS256"]) assert payload["sub"] == f"api_client:{client_id}" - expected_expires = ( - datetime.now(timezone.utc) + test_token_settings["expires_delta"] - ).timestamp() + expected_expires = (datetime.now(UTC) + test_token_settings["expires_delta"]).timestamp() assert -2.0 < (expected_expires - payload["exp"]) < 2.0 @@ -83,9 +75,7 @@ def test_post_token_succeeds_no_grant(anon_client, client_id_and_secret): assert resp.status_code == 200 -def test_post_token_succeeds_refresh_grant( - anon_client, test_token_settings, client_id_and_secret -): +def test_post_token_succeeds_refresh_grant(anon_client, test_token_settings, client_id_and_secret): """If grant_type is refresh_token, the token grant is successful.""" client_id, client_secret = client_id_and_secret resp = anon_client.post( @@ -131,9 +121,7 @@ def test_post_token_fails_unknown_api_client(anon_client, client_id_and_secret, """Authentication failes on unknown api_client ID.""" good_id, good_secret = client_id_and_secret with caplog.at_level(logging.INFO): - resp = anon_client.post( - "/token", auth=HTTPBasicAuth(good_id + "x", good_secret) - ) + resp = anon_client.post("/token", auth=HTTPBasicAuth(good_id + "x", good_secret)) assert resp.status_code == 400 assert resp.json() == {"detail": "Incorrect username or password"} assert caplog.records[0].token_creds_from == "header" @@ -144,18 +132,14 @@ def test_post_token_fails_bad_credentials(anon_client, client_id_and_secret, cap """Authentication fails on bad credentials.""" good_id, good_secret = client_id_and_secret with caplog.at_level(logging.INFO): - resp = anon_client.post( - "/token", auth=HTTPBasicAuth(good_id, good_secret + "x") - ) + resp = anon_client.post("/token", auth=HTTPBasicAuth(good_id, good_secret + "x")) assert resp.status_code == 400 assert resp.json() == {"detail": "Incorrect username or password"} assert caplog.records[0].token_creds_from == "header" assert caplog.records[0].token_fail == "Bad credentials" -def test_post_token_fails_disabled_client( - dbsession, anon_client, client_id_and_secret, caplog -): +def test_post_token_fails_disabled_client(dbsession, anon_client, client_id_and_secret, caplog): """Authentication fails when the client is disabled.""" client_id, client_secret = client_id_and_secret api_client = get_api_client_by_id(dbsession, client_id) @@ -169,16 +153,12 @@ def test_post_token_fails_disabled_client( assert caplog.records[0].token_fail == "Client disabled" -def test_get_ctms_with_token( - email_factory, anon_client, test_token_settings, client_id_and_secret -): +def test_get_ctms_with_token(email_factory, anon_client, test_token_settings, client_id_and_secret): """An authenticated API can be fetched with a valid token""" email = email_factory() client_id = client_id_and_secret[0] - token = create_access_token( - {"sub": f"api_client:{client_id}"}, **test_token_settings - ) + token = create_access_token({"sub": f"api_client:{client_id}"}, **test_token_settings) token_headers = jwt.get_unverified_header(token) assert token_headers == { "alg": "HS256", @@ -191,18 +171,14 @@ def test_get_ctms_with_token( assert resp.status_code == 200 -def test_successful_login_tracks_last_access( - dbsession, email_factory, anon_client, test_token_settings, client_id_and_secret -): +def test_successful_login_tracks_last_access(dbsession, email_factory, anon_client, test_token_settings, client_id_and_secret): client_id = client_id_and_secret[0] email = email_factory() api_client = get_api_client_by_id(dbsession, client_id) before = api_client.last_access - token = create_access_token( - {"sub": f"api_client:{client_id}"}, **test_token_settings - ) + token = create_access_token({"sub": f"api_client:{client_id}"}, **test_token_settings) anon_client.get( f"/ctms/{email.email_id}", headers={"Authorization": f"Bearer {token}"}, @@ -213,9 +189,7 @@ def test_successful_login_tracks_last_access( assert before != after -def test_get_ctms_with_invalid_token_fails( - email_factory, anon_client, test_token_settings, client_id_and_secret, caplog -): +def test_get_ctms_with_invalid_token_fails(email_factory, anon_client, test_token_settings, client_id_and_secret, caplog): """Calling an authenticated API with an invalid token is an error""" email = email_factory() @@ -235,9 +209,7 @@ def test_get_ctms_with_invalid_token_fails( assert caplog.records[0].auth_fail == "No or bad token" -def test_get_ctms_with_invalid_namespace_fails( - email_factory, anon_client, test_token_settings, client_id_and_secret, caplog -): +def test_get_ctms_with_invalid_namespace_fails(email_factory, anon_client, test_token_settings, client_id_and_secret, caplog): """Calling an authenticated API with an unexpected namespace is an error""" email = email_factory() @@ -253,16 +225,12 @@ def test_get_ctms_with_invalid_namespace_fails( assert caplog.records[0].auth_fail == "Bad namespace" -def test_get_ctms_with_unknown_client_fails( - email_factory, anon_client, test_token_settings, client_id_and_secret, caplog -): +def test_get_ctms_with_unknown_client_fails(email_factory, anon_client, test_token_settings, client_id_and_secret, caplog): """A token with an unknown (deleted?) API client name is an error""" email = email_factory() client_id = client_id_and_secret[0] - token = create_access_token( - {"sub": f"api_client:not_{client_id}"}, **test_token_settings - ) + token = create_access_token({"sub": f"api_client:not_{client_id}"}, **test_token_settings) with caplog.at_level(logging.INFO): resp = anon_client.get( f"/ctms/{email.email_id}", @@ -273,17 +241,13 @@ def test_get_ctms_with_unknown_client_fails( assert caplog.records[0].auth_fail == "No client record" -def test_get_ctms_with_expired_token_fails( - email_factory, anon_client, test_token_settings, client_id_and_secret, caplog -): +def test_get_ctms_with_expired_token_fails(email_factory, anon_client, test_token_settings, client_id_and_secret, caplog): """Calling an authenticated API with an expired token is an error""" email = email_factory() - yesterday = datetime.now(timezone.utc) - timedelta(days=1) + yesterday = datetime.now(UTC) - timedelta(days=1) client_id = client_id_and_secret[0] - token = create_access_token( - {"sub": f"api_client:{client_id}"}, **test_token_settings, now=yesterday - ) + token = create_access_token({"sub": f"api_client:{client_id}"}, **test_token_settings, now=yesterday) with caplog.at_level(logging.INFO): resp = anon_client.get( f"/ctms/{email.email_id}", @@ -306,9 +270,7 @@ def test_get_ctms_with_disabled_client_fails( email = email_factory() client_id = client_id_and_secret[0] - token = create_access_token( - {"sub": f"api_client:{client_id}"}, **test_token_settings - ) + token = create_access_token({"sub": f"api_client:{client_id}"}, **test_token_settings) api_client = get_api_client_by_id(dbsession, client_id) api_client.enabled = False dbsession.commit() diff --git a/tests/unit/test_crud.py b/tests/unit/test_crud.py index 62b2eab4..2d4ea782 100644 --- a/tests/unit/test_crud.py +++ b/tests/unit/test_crud.py @@ -1,6 +1,6 @@ """Test database operations""" -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from uuid import uuid4 import pytest @@ -8,11 +8,6 @@ from ctms.crud import ( count_total_contacts, - create_amo, - create_email, - create_fxa, - create_mofo, - create_newsletter, create_or_update_contact, get_bulk_contacts, get_contact_by_email_id, @@ -24,10 +19,7 @@ from ctms.database import ScopedSessionLocal from ctms.models import Email from ctms.schemas import ( - AddOnsInSchema, EmailInSchema, - FirefoxAccountsInSchema, - MozillaFoundationInSchema, NewsletterInSchema, ) from ctms.schemas.contact import ContactPutSchema @@ -99,17 +91,15 @@ def test_get_contact_by_email_id_miss(dbsession): (False, 2), ], ) -def test_get_bulk_contacts_mofo_relevant( - dbsession, email_factory, mofo_relevant_flag, num_contacts_returned -): +def test_get_bulk_contacts_mofo_relevant(dbsession, email_factory, mofo_relevant_flag, num_contacts_returned): email_factory() email_factory(with_mofo=True, mofo__mofo_relevant=True) email_factory(with_mofo=True, mofo__mofo_relevant=False) contacts = get_bulk_contacts( dbsession, - start_time=datetime.now(timezone.utc) - timedelta(minutes=1), - end_time=datetime.now(timezone.utc) + timedelta(minutes=1), + start_time=datetime.now(UTC) - timedelta(minutes=1), + end_time=datetime.now(UTC) + timedelta(minutes=1), limit=3, mofo_relevant=mofo_relevant_flag, ) @@ -117,7 +107,7 @@ def test_get_bulk_contacts_mofo_relevant( def test_get_bulk_contacts_time_bounds(dbsession, email_factory): - start_time = datetime.now(timezone.utc) + start_time = datetime.now(UTC) end_time = start_time + timedelta(minutes=2) email_factory(update_timestamp=start_time - timedelta(minutes=1)) @@ -130,8 +120,8 @@ def test_get_bulk_contacts_time_bounds(dbsession, email_factory): contacts = get_bulk_contacts( dbsession, - start_time=datetime.now(timezone.utc) - timedelta(minutes=1), - end_time=datetime.now(timezone.utc) + timedelta(minutes=1), + start_time=datetime.now(UTC) - timedelta(minutes=1), + end_time=datetime.now(UTC) + timedelta(minutes=1), limit=5, ) @@ -146,8 +136,8 @@ def test_get_bulk_contacts_limited(dbsession, email_factory): contacts = get_bulk_contacts( dbsession, - start_time=datetime.now(timezone.utc) - timedelta(minutes=1), - end_time=datetime.now(timezone.utc) + timedelta(minutes=1), + start_time=datetime.now(UTC) - timedelta(minutes=1), + end_time=datetime.now(UTC) + timedelta(minutes=1), limit=5, ) assert len(contacts) == 5 @@ -159,8 +149,8 @@ def test_get_bulk_contacts_after_email_id(dbsession, email_factory): [contact] = get_bulk_contacts( dbsession, - start_time=datetime.now(timezone.utc) - timedelta(minutes=1), - end_time=datetime.now(timezone.utc) + timedelta(minutes=1), + start_time=datetime.now(UTC) - timedelta(minutes=1), + end_time=datetime.now(UTC) + timedelta(minutes=1), limit=1, after_email_id=str(first_email.email_id), ) @@ -173,8 +163,8 @@ def test_get_bulk_contacts_one(dbsession, email_factory): [contact] = get_bulk_contacts( dbsession, - start_time=datetime.now(timezone.utc) - timedelta(minutes=1), - end_time=datetime.now(timezone.utc) + timedelta(minutes=1), + start_time=datetime.now(UTC) - timedelta(minutes=1), + end_time=datetime.now(UTC) + timedelta(minutes=1), limit=10, ) assert contact.email.email_id == email.email_id @@ -183,8 +173,8 @@ def test_get_bulk_contacts_one(dbsession, email_factory): def test_get_bulk_contacts_none(dbsession): bulk_contact_list = get_bulk_contacts( dbsession, - start_time=datetime.now(timezone.utc) + timedelta(days=1), - end_time=datetime.now(timezone.utc) + timedelta(days=1), + start_time=datetime.now(UTC) + timedelta(days=1), + end_time=datetime.now(UTC) + timedelta(days=1), limit=10, ) assert bulk_contact_list == [] @@ -242,9 +232,7 @@ def test_get_contact_by_any_id_missing(dbsession, email_factory): ("mofo_contact_id", "5e499cc0-eeb5-4f0e-aae6-a101721874b8"), ], ) -def test_get_multiple_contacts_by_any_id( - dbsession, email_factory, alt_id_name, alt_id_value -): +def test_get_multiple_contacts_by_any_id(dbsession, email_factory, alt_id_name, alt_id_value): """Two contacts can share the same: - fxa primary_email - amo user_id @@ -263,9 +251,7 @@ def test_get_multiple_contacts_by_any_id( with_mofo=True, mofo__mofo_contact_id="5e499cc0-eeb5-4f0e-aae6-a101721874b8", ) - [contact_a, contact_b] = get_contacts_by_any_id( - dbsession, **{alt_id_name: alt_id_value} - ) + [contact_a, contact_b] = get_contacts_by_any_id(dbsession, **{alt_id_name: alt_id_value}) assert contact_a.email.email_id != contact_b.email.email_id @@ -279,9 +265,7 @@ def test_create_or_update_contact_related_objects(dbsession, email_factory): new_source = "http://waitlists.example.com/" putdata = ContactPutSchema( email=EmailInSchema(email_id=email.email_id, primary_email=email.primary_email), - newsletters=[ - NewsletterInSchema(name=email.newsletters[0].name, source=new_source) - ], + newsletters=[NewsletterInSchema(name=email.newsletters[0].name, source=new_source)], waitlists=[WaitlistInSchema(name=email.waitlists[0].name, source=new_source)], ) create_or_update_contact(dbsession, email.email_id, putdata, None) @@ -308,9 +292,7 @@ def test_create_or_update_contact_timestamps(dbsession, email_factory): new_source = "http://waitlists.example.com" putdata = ContactPutSchema( email=EmailInSchema(email_id=email.email_id, primary_email=email.primary_email), - newsletters=[ - NewsletterInSchema(name=email.newsletters[0].name, source=new_source) - ], + newsletters=[NewsletterInSchema(name=email.newsletters[0].name, source=new_source)], waitlists=[WaitlistInSchema(name=email.waitlists[0].name, source=new_source)], ) create_or_update_contact(dbsession, email.email_id, putdata, None) diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py index 1400b6fc..2ce6fd40 100644 --- a/tests/unit/test_log.py +++ b/tests/unit/test_log.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- """Tests for logging helpers""" import logging -from unittest.mock import patch import pytest from dockerflow.logging import JsonLogFormatter @@ -64,15 +62,15 @@ def test_token_request_log(anon_client, client_id_and_secret, caplog): def test_log_omits_emails(client, email_factory, caplog): """The logger omits emails from query params.""" email = email_factory(with_fxa=True) - url = ( - f"/ctms?primary_email={email.primary_email}&fxa_primary_email={email.fxa.primary_email}" - f"&email_id={email.email_id}" - ) + url = f"/ctms?primary_email={email.primary_email}&fxa_primary_email={email.fxa.primary_email}&email_id={email.email_id}" with caplog.at_level(logging.INFO): resp = client.get(url) assert resp.status_code == 200 assert len(caplog.records) == 1 log = caplog.records[0] + assert email.primary_email not in log.message + assert email.fxa.primary_email not in log.message + assert str(email.email_id) not in log.message def test_log_crash(client, caplog): diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 74bddaee..ab950f53 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -123,9 +123,7 @@ def assert_duration_metric_obs( } bucket_labels = labels.copy() bucket_labels["le"] = str(limit) - assert ( - metrics_registry.get_sample_value(f"{base_name}_bucket", bucket_labels) == count - ) + assert metrics_registry.get_sample_value(f"{base_name}_bucket", bucket_labels) == count assert metrics_registry.get_sample_value(f"{base_name}_count", labels) == count assert metrics_registry.get_sample_value(f"{base_name}_sum", labels) < limit @@ -191,9 +189,7 @@ def test_bad_api_request(client, dbsession, registry, email_id, status_code): path = "/ctms/{email_id}" assert_request_metric_inc(registry, "GET", path, status_code) status_code_family = str(status_code)[0] + "xx" - assert_api_request_metric_inc( - registry, "GET", path, "test_client", status_code_family - ) + assert_api_request_metric_inc(registry, "GET", path, "test_client", status_code_family) def test_crash_request(client, dbsession, registry):