Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions fast_flights/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .cookies_impl import Cookies
from .core import get_flights_from_filter, get_flights
from .core import get_flights_from_filter, get_flights, DataSource, FetchMode
from .filter import create_filter
from .flights_impl import Airport, FlightData, Passengers, TFSData
from .flights_impl import Airport, FlightData, Passengers, TFSData, TripType, SeatType
from .schema import Flight, Result
from .search import search_airport

Expand All @@ -17,4 +17,8 @@
"search_airport",
"Cookies",
"get_flights",
"DataSource",
"FetchMode",
"TripType",
"SeatType",
]
31 changes: 23 additions & 8 deletions fast_flights/core.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import re
import json
from typing import List, Literal, Optional, Union, overload
from typing import List, Literal, Optional, Union, overload, get_args

from selectolax.lexbor import LexborHTMLParser, LexborNode

from .decoder import DecodedResult, ResultDecoder
from .schema import Flight, Result
from .flights_impl import FlightData, Passengers
from .flights_impl import FlightData, Passengers, TripType, SeatType
from .filter import TFSData
from .fallback_playwright import fallback_playwright_fetch
from .bright_data_fetch import bright_data_fetch
from .primp import Client, Response


DataSource = Literal['html', 'js']
FetchMode = Literal["common", "fallback", "force-fallback", "local", "bright-data"]

def fetch(params: dict) -> Response:
client = Client(impersonate="chrome_126", verify=False)
Expand All @@ -26,7 +27,7 @@ def get_flights_from_filter(
filter: TFSData,
currency: str = "",
*,
mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common",
mode: FetchMode = "common",
data_source: Literal['js'] = ...,
) -> Union[DecodedResult, None]: ...

Expand All @@ -35,17 +36,31 @@ def get_flights_from_filter(
filter: TFSData,
currency: str = "",
*,
mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common",
mode: FetchMode = "common",
data_source: Literal['html'],
) -> Result: ...

def get_flights_from_filter(
filter: TFSData,
currency: str = "",
*,
mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common",
mode: FetchMode = "common",
data_source: DataSource = 'html',
) -> Union[Result, DecodedResult, None]:
# Validate mode parameter
valid_modes = get_args(FetchMode)
if mode not in valid_modes:
raise ValueError(
f"Invalid fetch mode: {mode}. Must be one of {list(valid_modes)}"
)

# Validate data_source parameter
valid_data_sources = get_args(DataSource)
if data_source not in valid_data_sources:
raise ValueError(
f"Invalid data_source: {data_source}. Must be one of {list(valid_data_sources)}"
)

data = filter.as_b64()

params = {
Expand Down Expand Up @@ -86,10 +101,10 @@ def get_flights_from_filter(
def get_flights(
*,
flight_data: List[FlightData],
trip: Literal["round-trip", "one-way", "multi-city"],
trip: TripType,
passengers: Passengers,
seat: Literal["economy", "premium-economy", "business", "first"],
fetch_mode: Literal["common", "fallback", "force-fallback", "local", "bright-data"] = "common",
seat: SeatType,
fetch_mode: FetchMode = "common",
max_stops: Optional[int] = None,
data_source: DataSource = 'html',
) -> Union[Result, DecodedResult, None]:
Expand Down
8 changes: 4 additions & 4 deletions fast_flights/filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Literal, List, Optional
from .flights_impl import FlightData, Passengers, TFSData
from typing import List, Optional
from .flights_impl import FlightData, Passengers, TFSData, TripType, SeatType

def create_filter(
*,
flight_data: List[FlightData],
trip: Literal["round-trip", "one-way", "multi-city"],
trip: TripType,
passengers: Passengers,
seat: Literal["economy", "premium-economy", "business", "first"],
seat: SeatType,
max_stops: Optional[int] = None,
) -> TFSData:
"""Create a filter. (``?tfs=``)
Expand Down
49 changes: 46 additions & 3 deletions fast_flights/flights_impl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Typed implementation of flights_pb2.py"""

import base64
from datetime import datetime
from dataclasses import dataclass
from typing import Any, List, Optional, TYPE_CHECKING, Literal, Union
from typing import Any, List, Optional, TYPE_CHECKING, Literal, Union, get_args

from . import flights_pb2 as PB
from ._generated_enum import Airport
Expand All @@ -12,6 +13,10 @@

AIRLINE_ALLIANCES = ["SKYTEAM", "STAR_ALLIANCE", "ONEWORLD"]

# Type aliases for validation
TripType = Literal["round-trip", "one-way", "multi-city"]
SeatType = Literal["economy", "premium-economy", "business", "first"]

class FlightData:
"""Represents flight data.

Expand Down Expand Up @@ -39,6 +44,20 @@ def __init__(
max_stops: Optional[int] = None,
airlines: Optional[List[str]] = None,
):
# Validate date format and ensure it's not in the past
try:
date_obj = datetime.strptime(date, "%Y-%m-%d").date()
except ValueError:
raise ValueError(
f"Invalid date format: {date}. Date must be in YYYY-MM-DD format."
)

today = datetime.now().date()
if date_obj < today:
raise ValueError(
f"Date cannot be in the past. Provided date: {date}, Today: {today}"
)

self.date = date
self.from_airport = (
from_airport.value if isinstance(from_airport, Airport) else from_airport
Expand Down Expand Up @@ -163,9 +182,9 @@ def as_b64(self) -> bytes:
def from_interface(
*,
flight_data: List[FlightData],
trip: Literal["round-trip", "one-way", "multi-city"],
trip: TripType,
passengers: Passengers,
seat: Literal["economy", "premium-economy", "business", "first"],
seat: SeatType,
max_stops: Optional[int] = None, # Add max_stops to the method signature
):
"""Use ``?tfs=`` from an interface.
Expand All @@ -177,6 +196,30 @@ def from_interface(
seat ("economy" | "premium-economy" | "business" | "first"): Seat.
max_stops (int, optional): Maximum number of stops.
"""
# Validate trip parameter
valid_trips = get_args(TripType)
if trip not in valid_trips:
raise ValueError(
f"Invalid trip type: {trip}. Must be one of {list(valid_trips)}"
)

# Validate seat parameter
valid_seats = get_args(SeatType)
if seat not in valid_seats:
raise ValueError(
f"Invalid seat type: {seat}. Must be one of {list(valid_seats)}"
)

# Validate flight_data
if not flight_data:
raise ValueError("flight_data must contain at least one FlightData object")

# Validate trip-specific requirements
if trip == "round-trip" and len(flight_data) < 2:
raise ValueError(
f"round-trip requires at least 2 FlightData objects, but got {len(flight_data)}"
)

trip_t = {
"round-trip": PB.Trip.ROUND_TRIP,
"one-way": PB.Trip.ONE_WAY,
Expand Down