# trex

High-performance econometric estimation using PyTorch with first-class GPU support and automatic differentiation. Implements method of moments estimators (GMM, GEL), maximum likelihood models, and discrete choice models with modern deep learning workflows.
## Features

- Linear Regression: OLS with fixed effects via GPU-accelerated alternating projections
- Generalized Method of Moments (GMM): Two-step efficient GMM with HAC-robust inference
- Generalized Empirical Likelihood (GEL): Empirical likelihood, exponential tilting, and CUE estimators
- Maximum Likelihood: Logistic and Poisson regression with PyTorch optimizers
- Discrete Choice Models: Multinomial logit, probit, and low-rank logit for large choice sets
- Automatic GPU Detection: Seamless CPU/GPU operation with device management
- Heteroskedasticity-Robust Inference: HC0-HC3 standard errors for cross-sectional data
- HAC-Robust Inference: Newey-West covariance estimation for time series
- Custom Optimizers: Full access to PyTorch optimizer ecosystem (LBFGS, Adam, SGD)
- Batched Operations: Memory-efficient estimation for large datasets
- M-Series Mac Support: Native MPS backend support (no JAX Metal issues)
## Installation

```bash
git clone https://github.com/apoorvalal/trex
cd trex
uv venv
source .venv/bin/activate
uv sync
```

Generate API docs from docstrings with pdoc:

```bash
uv sync --extra docs
bash docs/build_api_docs.sh
```

The rendered site is written to docs/api/ by default.
## Quick Start

### Linear Regression with Fixed Effects

```python
import torch
from trex import LinearRegression

# Panel data: firms × years
n_firms, n_years = 100, 10
n_obs = n_firms * n_years
X = torch.randn(n_obs, 3)
firm_ids = torch.repeat_interleave(torch.arange(n_firms), n_years)
year_ids = torch.tile(torch.arange(n_years), (n_firms,))

# True DGP with fixed effects
true_coef = torch.tensor([1.5, -0.8, 0.3])
firm_effects = torch.randn(n_firms)[firm_ids]
year_effects = torch.randn(n_years)[year_ids]
y = X @ true_coef + firm_effects + year_effects + 0.1 * torch.randn(n_obs)

# Two-way fixed effects regression
model = LinearRegression()
model.fit(X, y, fe=[firm_ids, year_ids], se="HC1")
print(f"Coefficients: {model.params['coef']}")
print(f"Robust SE: {model.params['se']}")
```
### GMM Estimation

```python
import torch
from trex.gmm import GMMEstimator

# Simulated IV data (illustrative)
n = 1000
instruments = torch.randn(n, 3)
endogenous_vars = instruments @ torch.randn(3, 2) + torch.randn(n, 2)
outcome = endogenous_vars @ torch.tensor([1.0, -1.0]) + torch.randn(n)

# Define IV moment condition: E[Z'(Y - X'β)] = 0
def iv_moment(Z, Y, X, beta):
    return Z * (Y - X @ beta).unsqueeze(-1)

# Two-step efficient GMM
gmm = GMMEstimator(iv_moment, weighting_matrix="optimal", backend="torch")
gmm.fit(instruments, outcome, endogenous_vars, two_step=True)
print(gmm.summary())
```
### Maximum Likelihood

```python
import torch
from trex import LogisticRegression

# Binary response model
X = torch.randn(1000, 5)
true_coef = torch.tensor([0.5, 1.0, -0.8, 0.3, 0.2])
y = torch.bernoulli(torch.sigmoid(X @ true_coef))

# MLE with Fisher information-based standard errors
model = LogisticRegression(maxiter=100)
model.fit(X, y)
model.summary()  # Displays coefficients, SE, z-stats, p-values

# Predictions
probs = model.predict_proba(X)
classes = model.predict(X, threshold=0.5)
```
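The summary above reports Fisher-information-based standard errors. As a rough sketch of where those come from, independent of trex's internals (all names below are illustrative): maximize the logistic log-likelihood with LBFGS, then invert the observed information, i.e. the Hessian of the summed negative log-likelihood, via autograd.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(1000, 3)
beta_true = torch.tensor([0.5, -0.5, 0.25])
y = torch.bernoulli(torch.sigmoid(X @ beta_true))

# Fit logistic MLE with LBFGS
beta = torch.zeros(3, requires_grad=True)
opt = torch.optim.LBFGS([beta], max_iter=100)

def closure():
    opt.zero_grad()
    loss = F.binary_cross_entropy_with_logits(X @ beta, y, reduction="sum")
    loss.backward()
    return loss

opt.step(closure)

# Observed information = Hessian of the summed negative log-likelihood at beta_hat;
# standard errors are square roots of the diagonal of its inverse
nll = lambda b: F.binary_cross_entropy_with_logits(X @ b, y, reduction="sum")
H = torch.autograd.functional.hessian(nll, beta.detach())
se = torch.sqrt(torch.diag(torch.linalg.inv(H)))
```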
### Fixed-Effects GLMs

```python
import torch
from trex import LogisticRegression, PoissonRegression

n_firms, n_years = 50, 12
n_obs = n_firms * n_years
X = torch.randn(n_obs, 2)
firm_ids = torch.repeat_interleave(torch.arange(n_firms), n_years)
year_ids = torch.tile(torch.arange(n_years), (n_firms,))
offset = 0.1 * torch.randn(n_obs)
y_binary = torch.bernoulli(torch.sigmoid(X @ torch.tensor([0.5, -0.5])))  # binary outcome
y_count = torch.poisson(torch.exp(0.3 * X[:, 0]))  # count outcome

# Fixed-effects logit
logit = LogisticRegression(maxiter=100)
logit.fit(
    X,
    y_binary,
    fe=[firm_ids, year_ids],
    offset=offset,
    hdfe_index=0,
)
print(logit.params["coef"])
print(logit.params["se"])
print(logit.params["fe_se_diag"])  # diagonal SEs for the hdfe block

# Fixed-effects Poisson
poisson = PoissonRegression(maxiter=100)
poisson.fit(
    X,
    y_count,
    fe=[firm_ids],
)
print(poisson.params["coef"])
print(poisson.params["fe_coef"][0])
```

Sparse FE incidence matrices are also supported through `fe_design=[csr_block, ...]` when you already have one-hot FE structures in CSR or COO format.
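For instance, a one-hot CSR incidence block of the kind `fe_design` expects can be built directly in PyTorch: since each row has exactly one nonzero, the row pointers are just 0..n and the column indices are the group ids (a sketch with made-up data):

```python
import torch

firm_ids = torch.tensor([0, 0, 1, 1, 2, 2])
n_obs, n_firms = firm_ids.shape[0], 3

# One-hot incidence matrix in CSR format: row i has a single 1 in column firm_ids[i]
csr_block = torch.sparse_csr_tensor(
    crow_indices=torch.arange(n_obs + 1),
    col_indices=firm_ids,
    values=torch.ones(n_obs),
    size=(n_obs, n_firms),
)
```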
### Discrete Choice

```python
import torch
from trex.choice import LowRankLogit

# Large-scale choice data with varying assortments
n_users, n_items, rank = 1000, 100, 5
user_indices = torch.randint(0, n_users, (5000,))
chosen_items = torch.randint(0, n_items, (5000,))
assortments = torch.randint(0, 2, (5000, n_items)).float()  # Binary availability

# Factorized utility model: Θ = AB' with zero-sum normalization
model = LowRankLogit(rank=rank, n_users=n_users, n_items=n_items, lam=0.01)
model.fit(user_indices, chosen_items, assortments)

# Counterfactual analysis
baseline = torch.ones(100, n_items)
baseline[:, 50] = 0  # Product 50 unavailable
counterfactual = torch.ones(100, n_items)  # All products available
results = model.counterfactual(user_indices[:100], baseline, counterfactual)
print(f"Market share change: {results['market_share_change'][50]:.3f}")
```
## GPU Usage

All estimators automatically detect and use CUDA/MPS when available:

```python
# Automatic device detection
model = LinearRegression()  # Uses CUDA if available, else CPU
model.fit(X, y)

# Explicit device control
model_cpu = LinearRegression(device='cpu')
model_gpu = LogisticRegression(device='cuda')

# Move fitted models between devices
model.fit(X_cpu, y_cpu)
model.to('cuda')  # Transfer to GPU
predictions = model.predict(X_gpu)  # Input data auto-moved to model device
```

## Methodology

### GMM

The library implements Hansen's (1982) GMM framework. Given moment conditions
$E[g(X_i, \theta_0)] = 0$, the estimator solves

$$\hat{\theta} = \arg\min_{\theta} \; \bar{g}(\theta)'\, W \, \bar{g}(\theta),$$

where $\bar{g}(\theta) = \frac{1}{n}\sum_i g(X_i, \theta)$.

- First step: $W_1 = I$ (identity matrix)
- Second step: $W_2 = \hat{\Omega}^{-1}$, where $\hat{\Omega} = \frac{1}{n}\sum_i g_i g_i'$

Asymptotic distribution with optimal weighting:

$$\sqrt{n}\,(\hat{\theta} - \theta_0) \;\xrightarrow{d}\; N\!\left(0,\; (G'\Omega^{-1}G)^{-1}\right), \qquad G = E\!\left[\frac{\partial g(X_i, \theta_0)}{\partial \theta'}\right]$$
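In code, the second-step weighting matrix amounts to inverting the sample covariance of the moment contributions. A minimal sketch, using a random stand-in for first-step moments rather than the library's internals:

```python
import torch

torch.manual_seed(0)
n, k = 500, 3
g = torch.randn(n, k)  # stand-in for moment contributions g(X_i, theta_1)

# First step: identity weighting
W1 = torch.eye(k)

# Second step: W2 = Omega_hat^{-1}, with Omega_hat = (1/n) * sum_i g_i g_i'
Omega_hat = (g.T @ g) / n
W2 = torch.linalg.inv(Omega_hat)
```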
### HAC-Robust Inference

For time series or spatial dependence, the library implements the Newey-West (1987) HAC covariance estimator

$$\hat{\Omega}_{\text{HAC}} = \hat{\Gamma}_0 + \sum_{j=1}^{L} w_j \left(\hat{\Gamma}_j + \hat{\Gamma}_j'\right), \qquad \hat{\Gamma}_j = \frac{1}{n}\sum_{t=j+1}^{n} g_t\, g_{t-j}',$$

with Bartlett kernel weights $w_j = 1 - \frac{j}{L+1}$.
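A minimal standalone sketch of this estimator (plain PyTorch, not the library's implementation):

```python
import torch

def newey_west(g: torch.Tensor, lags: int) -> torch.Tensor:
    """HAC covariance of a (T × k) moment series with Bartlett kernel weights."""
    T, _ = g.shape
    g = g - g.mean(dim=0)          # center the moment series
    omega = (g.T @ g) / T          # Gamma_0
    for j in range(1, lags + 1):
        w = 1.0 - j / (lags + 1)   # Bartlett weight
        gamma_j = (g[j:].T @ g[:-j]) / T
        omega = omega + w * (gamma_j + gamma_j.T)
    return omega

torch.manual_seed(0)
g = torch.randn(200, 2)
S = newey_west(g, lags=4)
```

With `lags=0` this reduces to the heteroskedasticity-robust (White) covariance.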
### Fixed Effects

Multi-way fixed effects are eliminated via alternating projections (Gaure, 2013). For two-way effects, estimation proceeds on the transformed data

$$\ddot{y}_{it} = \ddot{x}_{it}'\beta + \ddot{\varepsilon}_{it},$$

where $\ddot{z}_{it} = z_{it} - \bar{z}_{i\cdot} - \bar{z}_{\cdot t} + \bar{z}_{\cdot\cdot}$ is the within transformation.
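The alternating-projections idea can be sketched in a few lines: repeatedly sweep out each set of group means until the series stops changing (illustrative code, not trex's implementation):

```python
import torch

def demean_two_way(z, ids1, ids2, iters=50, tol=1e-8):
    """Alternately subtract group means until convergence (Gaure, 2013)."""
    z = z.clone()
    for _ in range(iters):
        prev = z.clone()
        for ids in (ids1, ids2):
            n_groups = int(ids.max()) + 1
            sums = torch.zeros(n_groups).index_add_(0, ids, z)
            counts = torch.zeros(n_groups).index_add_(0, ids, torch.ones_like(z))
            z = z - (sums / counts)[ids]       # subtract each group's mean
        if (z - prev).abs().max() < tol:
            break
    return z

# Toy check: demeaned series has (approximately) zero mean within every group
ids1 = torch.tensor([0, 0, 1, 1, 2, 2])
ids2 = torch.tensor([0, 1, 0, 1, 0, 1])
zd = demean_two_way(torch.randn(6), ids1, ids2)
```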
For nonlinear FE models such as logit and Poisson, trex estimates
the fixed effects directly rather than applying a within transformation. Those
estimators can be useful for panel GLMs, but they remain subject to incidental
parameter bias in short panels.
### Low-Rank Logit

The low-rank logit model (Kallus & Udell, 2016) factorizes the user-item utility matrix as $\Theta = AB'$, with a zero-sum normalization. See the mathematical notes for a detailed exposition.
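A sketch of the resulting choice probabilities, ignoring the normalization and any fitting: utilities come from the factorization, and items outside the assortment are masked out of the softmax (illustrative, not the library's code):

```python
import torch

torch.manual_seed(0)
n_users, n_items, rank = 8, 6, 2
A = torch.randn(n_users, rank)   # user factors
B = torch.randn(n_items, rank)   # item factors
theta = A @ B.T                  # utilities, Θ = AB'

# Choice probabilities restricted to an assortment (1 = available)
assortment = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0, 1.0])
logits = theta + torch.log(assortment)   # -inf for unavailable items
probs = torch.softmax(logits, dim=1)     # unavailable items get probability 0
```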
## Advanced Usage

### Custom Optimizers

```python
import torch.optim as optim

# Use Adam instead of the default LBFGS
model = LogisticRegression(
    optimizer=optim.Adam,
    maxiter=1000,
)
```

### Linear Regression Solvers

```python
model_torch = LinearRegression(solver="torch")  # PyTorch lstsq (GPU-capable)
model_numpy = LinearRegression(solver="numpy")  # NumPy fallback
```

### Batched Estimation

For datasets exceeding GPU memory, use DataLoader for batched optimization:
```python
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=1024, shuffle=True)

# Custom training loop with gradient accumulation
# See notebooks for complete examples
```

## Performance

- GPU Acceleration: 10-100× speedup for large datasets (n > 10,000)
- Fixed Effects: Memory-efficient alternating projections scale to millions of observations
- Batched Operations: Vectorized computations throughout, compatible with torch.compile
- Numerical Stability: Eigenvalue regularization and pseudo-inverse for ill-conditioned problems
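The numerical-stability devices mentioned above can be illustrated on a near-singular Gram matrix (a sketch, not the library's exact code):

```python
import torch

torch.manual_seed(0)
n = 200
X = torch.randn(n, 3)
X = torch.cat([X, X[:, :1] + 1e-7 * torch.randn(n, 1)], dim=1)  # near-collinear column
G = X.T @ X  # ill-conditioned Gram matrix

# Eigenvalue regularization: shift the spectrum up by a small ridge term
eps = 1e-6 * torch.linalg.eigvalsh(G).max()
G_reg = G + eps * torch.eye(G.shape[0])

# Pseudo-inverse fallback: least-norm solution when G is (near-)singular
y = torch.randn(n)
beta = torch.linalg.pinv(G) @ (X.T @ y)
```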
## Comparison with jaxonometrics

trex is a PyTorch port of jaxonometrics with enhanced device management:

| Feature | jaxonometrics | trex |
|---|---|---|
| Backend | JAX | PyTorch |
| M-Series Mac | Metal issues | Native MPS support |
| GPU Support | CUDA/TPU | CUDA/MPS/CPU |
| Auto-diff | jax.grad | torch.autograd |
| Compilation | jax.jit | torch.compile |
| Device Management | Manual | Automatic with .to() |
## Contributing

Contributions welcome. See CONTRIBUTING.md for guidelines.
## Citation

```bibtex
@software{trex,
  title = {trex: GPU-accelerated econometrics in PyTorch},
  author = {Lal, Apoorva},
  year = {2025},
  url = {https://github.com/apoorvalal/trex}
}
```

## References

- Hansen, L. P. (1982). Large sample properties of generalized method of moments estimators. Econometrica, 50(4), 1029-1054.
- Newey, W. K., & West, K. D. (1987). A simple, positive semi-definite, heteroskedasticity and autocorrelation consistent covariance matrix. Econometrica, 55(3), 703-708.
- Gaure, S. (2013). OLS with multiple high dimensional category variables. Computational Statistics & Data Analysis, 66, 8-18.
- Kallus, N., & Udell, M. (2016). Dynamic assortment personalization in high dimensions. arXiv preprint arXiv:1610.05604.
## Related Projects

- jaxonometrics - JAX-based econometrics (parent project)
- pyfixest - Fast fixed effects estimation
- linearmodels - Panel data models
## License

MIT License. See LICENSE for details.