diff --git a/fast_flights/__init__.py b/fast_flights/__init__.py index db669926..f78d3081 100644 --- a/fast_flights/__init__.py +++ b/fast_flights/__init__.py @@ -1,7 +1,7 @@ from .cookies_impl import Cookies -from .core import get_flights_from_filter, get_flights +from .core import get_flights_from_filter, get_flights, DataSource, FetchMode from .filter import create_filter -from .flights_impl import Airport, FlightData, Passengers, TFSData +from .flights_impl import Airport, FlightData, Passengers, TFSData, TripType, SeatType from .schema import Flight, Result from .search import search_airport @@ -17,4 +17,8 @@ "search_airport", "Cookies", "get_flights", + "DataSource", + "FetchMode", + "TripType", + "SeatType", ] diff --git a/fast_flights/core.py b/fast_flights/core.py index 6b11aafb..a3af62b4 100644 --- a/fast_flights/core.py +++ b/fast_flights/core.py @@ -1,12 +1,12 @@ import re import json -from typing import List, Literal, Optional, Union, overload +from typing import List, Literal, Optional, Union, overload, get_args from selectolax.lexbor import LexborHTMLParser, LexborNode from .decoder import DecodedResult, ResultDecoder from .schema import Flight, Result -from .flights_impl import FlightData, Passengers +from .flights_impl import FlightData, Passengers, TripType, SeatType from .filter import TFSData from .fallback_playwright import fallback_playwright_fetch from .bright_data_fetch import bright_data_fetch @@ -14,6 +14,7 @@ DataSource = Literal['html', 'js'] +FetchMode = Literal["common", "fallback", "force-fallback", "local", "bright-data"] def fetch(params: dict) -> Response: client = Client(impersonate="chrome_126", verify=False) @@ -26,7 +27,7 @@ def get_flights_from_filter( filter: TFSData, currency: str = "", *, - mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common", + mode: FetchMode = "common", data_source: Literal['js'] = ..., ) -> Union[DecodedResult, None]: ... @@ -35,7 +36,7 @@ def get_flights_from_filter( filter: TFSData, currency: str = "", *, - mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common", + mode: FetchMode = "common", data_source: Literal['html'], ) -> Result: ... @@ -43,9 +44,23 @@ def get_flights_from_filter( filter: TFSData, currency: str = "", *, - mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common", + mode: FetchMode = "common", data_source: DataSource = 'html', ) -> Union[Result, DecodedResult, None]: + # Validate mode parameter + valid_modes = get_args(FetchMode) + if mode not in valid_modes: + raise ValueError( + f"Invalid fetch mode: {mode}. Must be one of {list(valid_modes)}" + ) + + # Validate data_source parameter + valid_data_sources = get_args(DataSource) + if data_source not in valid_data_sources: + raise ValueError( + f"Invalid data_source: {data_source}. Must be one of {list(valid_data_sources)}" + ) + data = filter.as_b64() params = { @@ -86,10 +101,10 @@ def get_flights_from_filter( def get_flights( *, flight_data: List[FlightData], - trip: Literal["round-trip", "one-way", "multi-city"], + trip: TripType, passengers: Passengers, - seat: Literal["economy", "premium-economy", "business", "first"], - fetch_mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common", + seat: SeatType, + fetch_mode: FetchMode = "common", max_stops: Optional[int] = None, data_source: DataSource = 'html', ) -> Union[Result, DecodedResult, None]: diff --git a/fast_flights/filter.py b/fast_flights/filter.py index 67171f0a..22702b8d 100644 --- a/fast_flights/filter.py +++ b/fast_flights/filter.py @@ -1,12 +1,12 @@ -from typing import Literal, List, Optional -from .flights_impl import FlightData, Passengers, TFSData +from typing import List, Optional +from .flights_impl import FlightData, Passengers, TFSData, TripType, SeatType def create_filter( *, flight_data: List[FlightData], - trip: Literal["round-trip", "one-way", "multi-city"], + trip: TripType, passengers: Passengers, - seat: Literal["economy", "premium-economy", "business", "first"], + seat: SeatType, max_stops: Optional[int] = None, ) -> TFSData: """Create a filter. (``?tfs=``) diff --git a/fast_flights/flights_impl.py b/fast_flights/flights_impl.py index 5bd49e0d..6a9c0873 100644 --- a/fast_flights/flights_impl.py +++ b/fast_flights/flights_impl.py @@ -1,8 +1,9 @@ """Typed implementation of flights_pb2.py""" import base64 +from datetime import datetime from dataclasses import dataclass -from typing import Any, List, Optional, TYPE_CHECKING, Literal, Union +from typing import Any, List, Optional, TYPE_CHECKING, Literal, Union, get_args from . import flights_pb2 as PB from ._generated_enum import Airport @@ -12,6 +13,10 @@ AIRLINE_ALLIANCES = ["SKYTEAM", "STAR_ALLIANCE", "ONEWORLD"] +# Type aliases for validation +TripType = Literal["round-trip", "one-way", "multi-city"] +SeatType = Literal["economy", "premium-economy", "business", "first"] + class FlightData: """Represents flight data. @@ -39,6 +44,20 @@ def __init__( max_stops: Optional[int] = None, airlines: Optional[List[str]] = None, ): + # Validate date format and ensure it's not in the past + try: + date_obj = datetime.strptime(date, "%Y-%m-%d").date() + except ValueError: + raise ValueError( + f"Invalid date format: {date}. Date must be in YYYY-MM-DD format." + ) + + today = datetime.now().date() + if date_obj < today: + raise ValueError( + f"Date cannot be in the past. Provided date: {date}, Today: {today}" + ) + self.date = date self.from_airport = ( from_airport.value if isinstance(from_airport, Airport) else from_airport @@ -163,9 +182,9 @@ def as_b64(self) -> bytes: def from_interface( *, flight_data: List[FlightData], - trip: Literal["round-trip", "one-way", "multi-city"], + trip: TripType, passengers: Passengers, - seat: Literal["economy", "premium-economy", "business", "first"], + seat: SeatType, max_stops: Optional[int] = None, # Add max_stops to the method signature ): """Use ``?tfs=`` from an interface. @@ -177,6 +196,30 @@ def from_interface( seat ("economy" | "premium-economy" | "business" | "first"): Seat. max_stops (int, optional): Maximum number of stops. """ + # Validate trip parameter + valid_trips = get_args(TripType) + if trip not in valid_trips: + raise ValueError( + f"Invalid trip type: {trip}. Must be one of {list(valid_trips)}" + ) + + # Validate seat parameter + valid_seats = get_args(SeatType) + if seat not in valid_seats: + raise ValueError( + f"Invalid seat type: {seat}. Must be one of {list(valid_seats)}" + ) + + # Validate flight_data + if not flight_data: + raise ValueError("flight_data must contain at least one FlightData object") + + # Validate trip-specific requirements + if trip == "round-trip" and len(flight_data) < 2: + raise ValueError( + f"round-trip requires at least 2 FlightData objects, but got {len(flight_data)}" + ) + trip_t = { "round-trip": PB.Trip.ROUND_TRIP, "one-way": PB.Trip.ONE_WAY,