diff --git a/environment.yml b/environment.yml
index 40b1fdc..2753e14 100644
--- a/environment.yml
+++ b/environment.yml
@@ -14,4 +14,5 @@ dependencies:
- linearmodels
- unicodeit
- Faker
+ - narwhals
\ No newline at end of file
diff --git a/main.pdf b/main.pdf
index df3530c..d008494 100644
Binary files a/main.pdf and b/main.pdf differ
diff --git a/pyfixest_tables.tex b/pyfixest_tables.tex
index 8ffd3e0..97ab1d3 100644
--- a/pyfixest_tables.tex
+++ b/pyfixest_tables.tex
@@ -11,11 +11,11 @@
& (1) & (2)\\
\midrule
X1 & -0.919*** & -0.007 \\
- & (0.066) & (0.035) \\
+ & (0.060) & (0.042) \\
X2 & & -0.015 \\
- & & (0.010) \\
+ & & (0.011) \\
\midrule
- Observations & 997 & 997 \\
+ Observations & 997 & 995 \\
$R^2$ & 0.609 & \\
\bottomrule
\multicolumn{3}{l}{{\small \textit{*p$<$0.1, **p$<$0.05, ***p$<$0.01}}}\\
diff --git a/pyproject.toml b/pyproject.toml
index 96b182a..b201b09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,8 +44,12 @@ test = [
"pytest-cov>=2.0",
"statsmodels",
"linearmodels",
- "pyfixest",
+ "unicodeit",
+ "narwhals",
+ "pyfixest>=0.40.1",
"faker",
+ "polars",
+ "pyarrow",
]
dev = [
diff --git a/requirements.txt b/requirements.txt
index 08a9e16..463a6a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,7 @@ numpy
pandas
scipy
unicodeit
+Faker
+narwhals
+polars
+pyarrow
\ No newline at end of file
diff --git a/samplenotebook.ipynb b/samplenotebook.ipynb
index 0d437c7..0f83c2e 100644
--- a/samplenotebook.ipynb
+++ b/samplenotebook.ipynb
@@ -69,7 +69,7 @@
"type": "integer"
}
],
- "ref": "1f4fb5a8-29a3-4f4a-b861-1b2cd019cfd6",
+ "ref": "dabae8ec-9941-4705-bacb-d4968ec32494",
"rows": [
[
"0",
@@ -3523,7 +3523,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n"
+ "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n",
+ "/Users/andersonfrailey/opt/anaconda3/envs/statstables-dev/lib/python3.13/site-packages/pyfixest/estimation/model_matrix_fixest_.py:215: UserWarning: 2 singleton fixed effect(s) detected. These observations are dropped from the model.\n",
+ " warnings.warn(\n"
]
},
{
@@ -3553,9 +3555,9 @@
"
\n",
" | \n",
"\n",
- " (0.066) | \n",
+ " (0.060) | \n",
"\n",
- " (0.035) | \n",
+ " (0.042) | \n",
"\n",
"
\n",
" \n",
@@ -3571,7 +3573,7 @@
"\n",
" | \n",
"\n",
- " (0.010) | \n",
+ " (0.011) | \n",
"\n",
"
\n",
" \n",
@@ -3582,7 +3584,7 @@
"\n",
" | 997 | \n",
"\n",
- " 997 | \n",
+ " 995 | \n",
"\n",
"
\n",
" \n",
@@ -3604,11 +3606,11 @@
"+ (1) (2) +\n",
"+------------------------------------------------+\n",
"+ X1 -0.919*** -0.007 +\n",
- "+ (0.066) (0.035) +\n",
+ "+ (0.060) (0.042) +\n",
"+ X2 -0.015 +\n",
- "+ (0.010) +\n",
+ "+ (0.011) +\n",
"--------------------------------------------------\n",
- "+ Observations 997 997 +\n",
+ "+ Observations 997 995 +\n",
"+ R² 0.609 +\n",
"--------------------------------------------------\n",
"*p<0.1, **p<0.05, ***p<0.01 "
diff --git a/statstables/tables.py b/statstables/tables.py
index 85e0d61..4a5a396 100644
--- a/statstables/tables.py
+++ b/statstables/tables.py
@@ -3,6 +3,8 @@
import pandas as pd
import numpy as np
import statstables as st
+import narwhals as nw
+from narwhals.typing import IntoDataFrame
from abc import ABC, abstractmethod
from scipy import stats
from typing import Union, Callable, overload
@@ -773,11 +775,11 @@ class GenericTable(Table):
column/index naming
"""
- def __init__(self, df: pd.DataFrame | pd.Series, **kwargs):
- self.df = df
- self.ncolumns = df.shape[1]
- self.columns = df.columns
- self.nrows = df.shape[0]
+ def __init__(self, df: IntoDataFrame, **kwargs):
+ self.df = nw.from_native(df).to_pandas()
+ self.ncolumns = self.df.shape[1]
+ self.columns = self.df.columns
+ self.nrows = self.df.shape[0]
super().__init__(**kwargs)
def reset_params(self, restore_to_defaults=False):
@@ -809,7 +811,7 @@ def _create_rows(self) -> list[list[ChainMap]]:
class MeanDifferenceTable(Table):
def __init__(
self,
- df: pd.DataFrame,
+ df: IntoDataFrame,
var_list: list,
group_var: str,
diff_pairs: list[tuple] | None = None,
@@ -842,7 +844,7 @@ def __init__(
Parameters
----------
- df : pd.DataFrame
+ df : IntoDataFrame
DataFrame containing the raw data to be compared
var_list : list
List of variables to compare means to between the groups
@@ -875,7 +877,8 @@ def __init__(
}
self.table_params = MeanDiffsTableParams(user_params)
# TODO: allow for grouping on multiple variables
- self.groups = df[group_var].unique()
+ self.df = nw.from_native(df).to_pandas()
+ self.groups = self.df[group_var].unique()
self.ngroups = len(self.groups)
self.var_list = var_list
if self.ngroups > 2 and not diff_pairs:
@@ -885,14 +888,14 @@ def __init__(
if self.ngroups < 2:
raise ValueError("There must be at least two groups")
self.alternative = alternative
- self.type_gdf = df.groupby(group_var)
+ self.type_gdf = self.df.groupby(group_var)
# adjust these to only count non-null values
self.grp_sizes = self.type_gdf.size()
- self.grp_sizes["Overall Mean"] = df.shape[0]
+ self.grp_sizes["Overall Mean"] = self.df.shape[0]
self.means = self.type_gdf[var_list].mean().T
# add toal means column to means
- self.means["Overall Mean"] = df[var_list].mean()
- total_sem = df[var_list].sem()
+ self.means["Overall Mean"] = self.df[var_list].mean()
+ total_sem = self.df[var_list].sem()
assert isinstance(total_sem, pd.Series)
total_sem.name = "Overall Mean"
self.sem = pd.merge(
@@ -1106,10 +1109,11 @@ def _create_rows(self) -> list[list[ChainMap]]:
class SummaryTable(GenericTable):
- def __init__(self, df: pd.DataFrame, var_list: list[str] | None = None, **kwargs):
+ def __init__(self, df: IntoDataFrame, var_list: list[str] | None = None, **kwargs):
+ self.df = nw.from_native(df).to_pandas()
if var_list is None:
- var_list = list(df.columns)
- summary_df = df[var_list].describe()
+ var_list = list(self.df.columns)
+ summary_df = self.df[var_list].describe()
super().__init__(summary_df, **kwargs)
def reset_custom_features(self):
diff --git a/statstables/tests/test_tables.py b/statstables/tests/test_tables.py
index 07e7146..58203f6 100644
--- a/statstables/tests/test_tables.py
+++ b/statstables/tests/test_tables.py
@@ -5,6 +5,7 @@
import copy
import pytest
import pandas as pd
+import polars as pl
import numpy as np
import statsmodels.formula.api as smf
import pyfixest as pf
@@ -52,6 +53,14 @@ def test_generic_table(data):
table = tables.GenericTable(df)
print(table)
+ # test with polars dataframe to test narwhals implementation
+ pl_data = pl.from_pandas(data)
+ table = tables.GenericTable(df=pl_data)
+
+ table.render_ascii()
+ table.render_html()
+ table.render_latex()
+
def test_summary_table(data):
table = tables.SummaryTable(df=data, var_list=["A", "B", "C"])