Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,75 @@

import pandas as pd

# Optional JAX availability detection
try:
    import jax  # noqa: F401

    HAS_JAX = True
except ImportError:
    HAS_JAX = False

# Optional torch availability detection
try:
    import torch

    HAS_TORCH = True
    # MPS (Apple Silicon) / CUDA flags are only probed when torch imports;
    # the hasattr guard covers older torch builds without the mps backend.
    HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    HAS_CUDA = torch.cuda.is_available()
except ImportError:
    HAS_TORCH = False
    HAS_MPS = False
    HAS_CUDA = False

# Optional CuPy availability detection
try:
    import cupy

    # getDeviceCount() can raise a non-ImportError (e.g. a CUDA runtime
    # error) when CuPy is installed but no usable device/driver exists,
    # so the broad catch below is deliberate.
    HAS_CUPY = cupy.cuda.runtime.getDeviceCount() > 0
except Exception:  # was `(ImportError, Exception)` -- Exception already covers ImportError
    HAS_CUPY = False

# Backends that accept a backend= argument when called through pyfixest runners
_PYFIXEST_BACKENDS = {
"scipy",
"numba",
"rust",
"jax",
"torch_cpu",
"torch_mps",
"torch_cuda",
"torch_cuda32",
"cupy",
"cupy32",
"cupy64",
}

# =============================================================================
# Helpers
# =============================================================================


def _append_optional_backends(estimators, label_prefix, runner_func, func_name):
    """Append optional-backend (JAX, torch, CuPy) estimators based on runtime availability.

    Mutates ``estimators`` in place.  Each appended entry is a 5-tuple
    ``(label, backend, runner, run_in_subprocess, func_name)`` matching the
    entries built in ``get_estimators``, with the backend name echoed in the
    label, e.g. ``"pyfixest.feols (torch_cpu)"``.

    Parameters
    ----------
    estimators : list
        Estimator list to extend in place.
    label_prefix : str
        Human-readable label prefix for the benchmark rows.
    runner_func : callable
        Runner invoked as ``runner_func(data, formula, backend)`` for
        pyfixest backends.
    func_name : str
        Identifier recorded alongside each appended entry.
    """
    # The original (suffix, backend) pairs were always identical, so a flat
    # list of backend names suffices.  Order matters for stable output.
    backends: list[str] = []
    if HAS_JAX:
        backends.append("jax")
    if HAS_TORCH:
        backends.append("torch_cpu")
    if HAS_MPS:
        backends.append("torch_mps")
    if HAS_CUDA:
        backends.extend(("torch_cuda", "torch_cuda32"))
    if HAS_CUPY:
        backends.extend(("cupy64", "cupy32"))
    for backend in backends:
        estimators.append(
            (f"{label_prefix} ({backend})", backend, runner_func, False, func_name)
        )


# =============================================================================
# Estimator functions (run in main process for JIT caching)
# =============================================================================
Expand Down Expand Up @@ -220,6 +289,11 @@ def get_estimators(
False,
"pyfixest_feols",
),
]
_append_optional_backends(
estimators, "pyfixest.feols", run_pyfixest_feols, "pyfixest_feols"
)
estimators += [
(
"linearmodels.AbsorbingLS",
"absorbingls",
Expand Down Expand Up @@ -263,6 +337,9 @@ def get_estimators(
"pyfixest_fepois",
),
]
_append_optional_backends(
estimators, "pyfixest.fepois", run_pyfixest_fepois, "pyfixest_fepois"
)
formulas = {
2: "negbin_y ~ x1 | indiv_id + year",
3: "negbin_y ~ x1 | indiv_id + year + firm_id",
Expand Down Expand Up @@ -291,6 +368,12 @@ def get_estimators(
"pyfixest_feglm_logit",
),
]
_append_optional_backends(
estimators,
"pyfixest.feglm_logit",
run_pyfixest_feglm_logit,
"pyfixest_feglm_logit",
)
formulas = {
2: "binary_y ~ x1 | indiv_id + year",
3: "binary_y ~ x1 | indiv_id + year + firm_id",
Expand All @@ -310,6 +393,7 @@ def parse_dataset_name(name: str) -> tuple[str, int]:
"500k": 500_000,
"1m": 1_000_000,
"2m": 2_000_000,
"3m": 3_000_000,
"5m": 5_000_000,
}
parts = name.rsplit("_", 1)
Expand All @@ -333,11 +417,17 @@ def run_benchmark(
timeout_estimators: set[str] | None = None,
formulas_override: dict[int, str] | None = None,
allowed_datasets: set[str] | None = None,
pyfixest_only: bool = False,
backend_filter: set[str] | None = None,
) -> None:
"""Run benchmarks on all datasets in data_dir."""
if timeout_estimators is None:
timeout_estimators = set()
estimators, formulas = get_estimators(benchmark_type, timeout_estimators)
if pyfixest_only:
estimators = [e for e in estimators if e[1] in _PYFIXEST_BACKENDS]
if backend_filter:
estimators = [e for e in estimators if e[1] in backend_filter]
if formulas_override:
formulas = formulas_override

Expand Down Expand Up @@ -420,7 +510,7 @@ def run_benchmark(
print(f"{elapsed:.3f}s")
else:
# Run in main process
if backend_or_func in ("scipy", "numba", "rust"):
if backend_or_func in _PYFIXEST_BACKENDS:
elapsed = func(data, formula, backend_or_func)
else:
elapsed = func(data, formula)
Expand Down Expand Up @@ -527,6 +617,17 @@ def main():
default=None,
help="Filter datasets by name (e.g., 'simple' to exclude 'difficult')",
)
parser.add_argument(
"--pyfixest-only",
action="store_true",
help="Skip non-pyfixest estimators (linearmodels, statsmodels)",
)
parser.add_argument(
"--backends",
type=str,
default=None,
help="Comma-separated list of backends to run (e.g., 'torch_cuda,torch_cuda32,cupy64,cupy32')",
)
args = parser.parse_args()

config = load_config("bench.json")
Expand All @@ -537,6 +638,10 @@ def main():
formulas_override = get_formulas_from_config(config, args.type)
allowed_datasets = get_allowed_datasets(config, args.type)

backend_filter = None
if args.backends:
backend_filter = set(b.strip() for b in args.backends.split(","))

run_benchmark(
args.data_dir,
args.output,
Expand All @@ -546,6 +651,8 @@ def main():
timeout_estimators,
formulas_override,
allowed_datasets,
pyfixest_only=args.pyfixest_only,
backend_filter=backend_filter,
)


Expand Down
110 changes: 110 additions & 0 deletions benchmarks/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
"timeout_secs": {
"python": 60
},
"python_timeout_estimators": ["linearmodels.AbsorbingLS", "statsmodels.OLS"],
"iterations": {
"n_iters": 3,
"burn_in": 1
},
"formulas": {
"ols": {
"2": {
"python": "y ~ x1 | indiv_id + year"
},
"3": {
"python": "y ~ x1 | indiv_id + year + firm_id"
}
},
"poisson": {
"2": {
"python": "negbin_y ~ x1 | indiv_id + year"
},
"3": {
"python": "negbin_y ~ x1 | indiv_id + year + firm_id"
}
},
"logit": {
"2": {
"python": "binary_y ~ x1 | indiv_id + year"
},
"3": {
"python": "binary_y ~ x1 | indiv_id + year + firm_id"
}
}
},
"datasets": [
{ "name": "simple_1k", "n": 1000, "type": "simple" },
{ "name": "difficult_1k", "n": 1000, "type": "difficult" },
{ "name": "simple_10k", "n": 10000, "type": "simple" },
{ "name": "difficult_10k", "n": 10000, "type": "difficult" },
{ "name": "simple_100k", "n": 100000, "type": "simple" },
{ "name": "difficult_100k", "n": 100000, "type": "difficult" },
{ "name": "simple_500k", "n": 500000, "type": "simple" },
{ "name": "difficult_500k", "n": 500000, "type": "difficult" },
{ "name": "simple_1m", "n": 1000000, "type": "simple" },
{ "name": "difficult_1m", "n": 1000000, "type": "difficult" },
{ "name": "simple_2m", "n": 2000000, "type": "simple" },
{ "name": "difficult_2m", "n": 2000000, "type": "difficult" },
{ "name": "simple_3m", "n": 3000000, "type": "simple" },
{ "name": "difficult_3m", "n": 3000000, "type": "difficult" },
{ "name": "simple_5m", "n": 5000000, "type": "simple" },
{ "name": "difficult_5m", "n": 5000000, "type": "difficult" }
],
"datasets_by_type": {
"ols": [
"simple_1k",
"difficult_1k",
"simple_10k",
"difficult_10k",
"simple_100k",
"difficult_100k",
"simple_500k",
"difficult_500k",
"simple_1m",
"difficult_1m",
"simple_2m",
"difficult_2m",
"simple_3m",
"difficult_3m",
"simple_5m",
"difficult_5m"
],
"poisson": [
"simple_1k",
"difficult_1k",
"simple_10k",
"difficult_10k",
"simple_100k",
"difficult_100k",
"simple_500k",
"difficult_500k",
"simple_1m",
"difficult_1m",
"simple_2m",
"difficult_2m",
"simple_3m",
"difficult_3m",
"simple_5m",
"difficult_5m"
],
"logit": [
"simple_1k",
"difficult_1k",
"simple_10k",
"difficult_10k",
"simple_100k",
"difficult_100k",
"simple_500k",
"difficult_500k",
"simple_1m",
"difficult_1m",
"simple_2m",
"difficult_2m",
"simple_3m",
"difficult_3m",
"simple_5m",
"difficult_5m"
]
}
}
38 changes: 38 additions & 0 deletions pyfixest/estimation/internals/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,29 @@
crv1_meat_loop_cupy = crv1_meat_loop_nb
count_fixef_fully_nested_all_cupy = count_fixef_fully_nested_all_nb

# Try to import Torch functions, fall back to numba if not available
try:
    # Only the demeaning step has dedicated torch implementations: one
    # variant per device/precision target (generic, cpu, cuda, cuda-fp32, mps).
    from pyfixest.estimation.torch.demean_torch_ import (
        demean_torch,
        demean_torch_cpu,
        demean_torch_cuda,
        demean_torch_cuda32,
        demean_torch_mps,
    )

    TORCH_AVAILABLE = True
except ImportError:
    # torch is optional: alias every torch demeaner to the numba
    # implementation so the BACKENDS table below stays well-formed.
    demean_torch = demean_nb
    demean_torch_cpu = demean_nb
    demean_torch_mps = demean_nb
    demean_torch_cuda = demean_nb
    demean_torch_cuda32 = demean_nb
    TORCH_AVAILABLE = False

# The non-demeaning kernels have no torch implementation, so the torch
# backends always reuse the numba versions -- even when torch is available.
find_collinear_variables_torch = find_collinear_variables_nb
crv1_meat_loop_torch = crv1_meat_loop_nb
count_fixef_fully_nested_all_torch = count_fixef_fully_nested_all_nb

BACKENDS = {
"numba": {
"demean": demean_nb,
Expand Down Expand Up @@ -96,4 +119,19 @@
"crv1_meat": crv1_meat_loop_cupy,
"nonnested": count_fixef_fully_nested_all_cupy,
},
**{
name: {
"demean": demean_fn,
"collinear": find_collinear_variables_torch,
"crv1_meat": crv1_meat_loop_torch,
"nonnested": count_fixef_fully_nested_all_torch,
}
for name, demean_fn in [
("torch", demean_torch),
("torch_cpu", demean_torch_cpu),
("torch_mps", demean_torch_mps),
("torch_cuda", demean_torch_cuda),
("torch_cuda32", demean_torch_cuda32),
]
},
}
28 changes: 5 additions & 23 deletions pyfixest/estimation/internals/demean_.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,27 +346,9 @@ def _set_demeaner_backend(
ValueError
If the demeaning backend is not supported.
"""
if demeaner_backend == "rust":
from pyfixest.core.demean import demean as demean_rs
from pyfixest.estimation.internals.backends import BACKENDS

return demean_rs
elif demeaner_backend == "rust-cg":
from pyfixest.core.demean import demean_within

return demean_within
elif demeaner_backend == "numba":
return demean
elif demeaner_backend == "jax":
from pyfixest.estimation.jax.demean_jax_ import demean_jax

return demean_jax
elif demeaner_backend in ["cupy", "cupy64"]:
from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy64

return demean_cupy64
elif demeaner_backend == "cupy32":
from pyfixest.estimation.cupy.demean_cupy_ import demean_cupy32

return demean_cupy32
else:
raise ValueError(f"Invalid demeaner backend: {demeaner_backend}")
try:
return BACKENDS[demeaner_backend]["demean"]
except KeyError as exc:
raise ValueError(f"Invalid demeaner backend: {demeaner_backend}") from exc
14 changes: 13 additions & 1 deletion pyfixest/estimation/internals/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,19 @@
"jax",
]
# Accepted values for the demeaner_backend argument.  Member order is
# preserved from the original definition.
DemeanerBackendOptions = Literal[
    # CPU implementations
    "numba",
    "jax",
    # Rust implementations
    "rust",
    "rust-cg",
    # CuPy implementations
    "cupy",
    "cupy32",
    "cupy64",
    "scipy",
    # PyTorch device/precision variants
    "torch",
    "torch_cpu",
    "torch_mps",
    "torch_cuda",
    "torch_cuda32",
]
PredictionErrorOptions = Literal["prediction"]
QuantregMethodOptions = Literal["fn", "pfn"]
Expand Down
Loading