Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ dependencies:
- linearmodels
- unicodeit
- Faker
- narwhals

Binary file modified main.pdf
Binary file not shown.
6 changes: 3 additions & 3 deletions pyfixest_tables.tex
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
& (1) & (2)\\
\midrule
X1 & -0.919*** & -0.007 \\
& (0.066) & (0.035) \\
& (0.060) & (0.042) \\
X2 & & -0.015 \\
& & (0.010) \\
& & (0.011) \\
\midrule
Observations & 997 & 997 \\
Observations & 997 & 995 \\
$R^2$ & 0.609 & \\
\bottomrule
\multicolumn{3}{l}{{\small \textit{*p$<$0.1, **p$<$0.05, ***p$<$0.01}}}\\
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ test = [
"pytest-cov>=2.0",
"statsmodels",
"linearmodels",
"pyfixest",
"unicodeit",
"narwhals",
"pyfixest>=0.40.1",
"faker",
"polars",
"pyarrow",
]

dev = [
Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@ numpy
pandas
scipy
unicodeit
Faker
narwhals
polars
pyarrow
20 changes: 11 additions & 9 deletions samplenotebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"type": "integer"
}
],
"ref": "1f4fb5a8-29a3-4f4a-b861-1b2cd019cfd6",
"ref": "dabae8ec-9941-4705-bacb-d4968ec32494",
"rows": [
[
"0",
Expand Down Expand Up @@ -3523,7 +3523,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n"
"OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n",
"/Users/andersonfrailey/opt/anaconda3/envs/statstables-dev/lib/python3.13/site-packages/pyfixest/estimation/model_matrix_fixest_.py:215: UserWarning: 2 singleton fixed effect(s) detected. These observations are dropped from the model.\n",
" warnings.warn(\n"
]
},
{
Expand Down Expand Up @@ -3553,9 +3555,9 @@
" <tr>\n",
" <td style=\"text-align: left;\"></td>\n",
"\n",
" <td style=\"text-align: center;\">(0.066)</td>\n",
" <td style=\"text-align: center;\">(0.060)</td>\n",
"\n",
" <td style=\"text-align: center;\">(0.035)</td>\n",
" <td style=\"text-align: center;\">(0.042)</td>\n",
"\n",
" </tr>\n",
" <tr>\n",
Expand All @@ -3571,7 +3573,7 @@
"\n",
" <td style=\"text-align: center;\"></td>\n",
"\n",
" <td style=\"text-align: center;\">(0.010)</td>\n",
" <td style=\"text-align: center;\">(0.011)</td>\n",
"\n",
" </tr>\n",
" <tr>\n",
Expand All @@ -3582,7 +3584,7 @@
"\n",
" <td style=\"text-align: center;\">997</td>\n",
"\n",
" <td style=\"text-align: center;\">997</td>\n",
" <td style=\"text-align: center;\">995</td>\n",
"\n",
" </tr>\n",
" <tr>\n",
Expand All @@ -3604,11 +3606,11 @@
"+ (1) (2) +\n",
"+------------------------------------------------+\n",
"+ X1 -0.919*** -0.007 +\n",
"+ (0.066) (0.035) +\n",
"+ (0.060) (0.042) +\n",
"+ X2 -0.015 +\n",
"+ (0.010) +\n",
"+ (0.011) +\n",
"--------------------------------------------------\n",
"+ Observations 997 997 +\n",
"+ Observations 997 995 +\n",
"+ R² 0.609 +\n",
"--------------------------------------------------\n",
"*p<0.1, **p<0.05, ***p<0.01 "
Expand Down
34 changes: 19 additions & 15 deletions statstables/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pandas as pd
import numpy as np
import statstables as st
import narwhals as nw
from narwhals.typing import IntoDataFrame
from abc import ABC, abstractmethod
from scipy import stats
from typing import Union, Callable, overload
Expand Down Expand Up @@ -773,11 +775,11 @@ class GenericTable(Table):
column/index naming
"""

def __init__(self, df: pd.DataFrame | pd.Series, **kwargs):
self.df = df
self.ncolumns = df.shape[1]
self.columns = df.columns
self.nrows = df.shape[0]
def __init__(self, df: IntoDataFrame, **kwargs):
self.df = nw.from_native(df).to_pandas()
self.ncolumns = self.df.shape[1]
self.columns = self.df.columns
self.nrows = self.df.shape[0]
super().__init__(**kwargs)

def reset_params(self, restore_to_defaults=False):
Expand Down Expand Up @@ -809,7 +811,7 @@ def _create_rows(self) -> list[list[ChainMap]]:
class MeanDifferenceTable(Table):
def __init__(
self,
df: pd.DataFrame,
df: IntoDataFrame,
var_list: list,
group_var: str,
diff_pairs: list[tuple] | None = None,
Expand Down Expand Up @@ -842,7 +844,7 @@ def __init__(

Parameters
----------
df : pd.DataFrame
df : IntoDataFrame
DataFrame containing the raw data to be compared
var_list : list
List of variables to compare means to between the groups
Expand Down Expand Up @@ -875,7 +877,8 @@ def __init__(
}
self.table_params = MeanDiffsTableParams(user_params)
# TODO: allow for grouping on multiple variables
self.groups = df[group_var].unique()
self.df = nw.from_native(df).to_pandas()
self.groups = self.df[group_var].unique()
self.ngroups = len(self.groups)
self.var_list = var_list
if self.ngroups > 2 and not diff_pairs:
Expand All @@ -885,14 +888,14 @@ def __init__(
if self.ngroups < 2:
raise ValueError("There must be at least two groups")
self.alternative = alternative
self.type_gdf = df.groupby(group_var)
self.type_gdf = self.df.groupby(group_var)
# adjust these to only count non-null values
self.grp_sizes = self.type_gdf.size()
self.grp_sizes["Overall Mean"] = df.shape[0]
self.grp_sizes["Overall Mean"] = self.df.shape[0]
self.means = self.type_gdf[var_list].mean().T
# add toal means column to means
self.means["Overall Mean"] = df[var_list].mean()
total_sem = df[var_list].sem()
self.means["Overall Mean"] = self.df[var_list].mean()
total_sem = self.df[var_list].sem()
assert isinstance(total_sem, pd.Series)
total_sem.name = "Overall Mean"
self.sem = pd.merge(
Expand Down Expand Up @@ -1106,10 +1109,11 @@ def _create_rows(self) -> list[list[ChainMap]]:


class SummaryTable(GenericTable):
def __init__(self, df: pd.DataFrame, var_list: list[str] | None = None, **kwargs):
def __init__(self, df: IntoDataFrame, var_list: list[str] | None = None, **kwargs):
self.df = nw.from_native(df).to_pandas()
if var_list is None:
var_list = list(df.columns)
summary_df = df[var_list].describe()
var_list = list(self.df.columns)
summary_df = self.df[var_list].describe()
super().__init__(summary_df, **kwargs)

def reset_custom_features(self):
Expand Down
9 changes: 9 additions & 0 deletions statstables/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import pytest
import pandas as pd
import polars as pl
import numpy as np
import statsmodels.formula.api as smf
import pyfixest as pf
Expand Down Expand Up @@ -52,6 +53,14 @@ def test_generic_table(data):
table = tables.GenericTable(df)
print(table)

# test with polars dataframe to test narwhals implementation
pl_data = pl.from_pandas(data)
table = tables.GenericTable(df=pl_data)

table.render_ascii()
table.render_html()
table.render_latex()


def test_summary_table(data):
table = tables.SummaryTable(df=data, var_list=["A", "B", "C"])
Expand Down
Loading